diff --git a/.conda/bld.bat b/.conda/bld.bat index 3c9668685..157b97b14 100644 --- a/.conda/bld.bat +++ b/.conda/bld.bat @@ -13,7 +13,7 @@ rem # this out myself, ughhh. set PIP_NO_INDEX=False set PIP_NO_DEPENDENCIES=False set PIP_IGNORE_INSTALLED=False -pip install cattrs==1.0.0rc opencv-python==3.4.2.17 PySide2==5.12.0 imgaug qimage2ndarray==1.8 imgstore +pip install cattrs==1.0.0rc opencv-python-headless==3.4.1.15 PySide2==5.12.0 imgaug qimage2ndarray==1.8 imgstore rem # Use and update environment.yml call to install pip dependencies. This is slick. rem # While environment.yml contains the non pip dependencies, the only thing left diff --git a/.coveralls.yml b/.coveralls.yml new file mode 100644 index 000000000..d09470693 --- /dev/null +++ b/.coveralls.yml @@ -0,0 +1 @@ +service_name: appveyor \ No newline at end of file diff --git a/README.rst b/README.rst index 38dbdd609..d928ca360 100644 --- a/README.rst +++ b/README.rst @@ -10,7 +10,7 @@ .. |GitHub release| image:: https://img.shields.io/github/release/murthylab/sleap.js.svg :target: https://GitHub.com/murthylab/sleap/releases/ -Social LEAP Estimates Animal Pose (sLEAP) +Social LEAP Estimates Animal Pose (SLEAP) ========================================= .. image:: docs/_static/supp_mov1-long_clip.gif @@ -18,8 +18,8 @@ Social LEAP Estimates Animal Pose (sLEAP) | -**S**\ ocial **L**\ EAP **E**\ stimates **A**\ nimal **P**\ ose (**sLEAP**) is a framework for multi-animal -body part position estimation via deep learning. It is the successor to LEAP_. **sLEAP** is written entirely in +**S**\ ocial **L**\ EAP **E**\ stimates **A**\ nimal **P**\ ose (**SLEAP**) is a framework for multi-animal +body part position estimation via deep learning. It is the successor to LEAP_. **SLEAP** is written entirely in Python, supports multi-animal pose estimation, animal instance tracking, and a labeling/training GUI that supports active learning. @@ -30,19 +30,19 @@ supports active learning. Installation ------------ -**sLEAP** is compatible with Python versions 3.6 and above, with support for Windows and Linux. Mac OS X works but without GPU support. +**SLEAP** is compatible with Python versions 3.6 and above, with support for Windows and Linux. Mac OS X works but without GPU support. Windows ------- -Since **sLEAP** has a number of complex binary dependencies (TensorFlow, Keras, OpenCV), it is recommended to use the +Since **SLEAP** has a number of complex binary dependencies (TensorFlow, Keras, OpenCV), it is recommended to use the Anaconda_ Python distribution to simplify installation. Once Anaconda_ has been installed, go to start menu and type in *Anaconda*, which should bring up a menu entry **Anaconda Prompt** which opens a command line with the base anaconda environment activated. One of the key advantages to using `Anaconda Environments`_ is the ability to create separate Python installations (environments) for different projects, mitigating issues of managing complex dependencies. To create a new conda environment for -**sLEAP** related development and use: +**SLEAP** related development and use: :: @@ -59,7 +59,7 @@ Any Python installation commands (:code:`conda install` or :code:`pip install`) environment will only effect the environment. Thus it is important to make sure the environment is active when issuing any commands that deal with Python on the command line. -**sLEAP** is now installed in the :code:`sleap_env` conda environment. With the environment active, +**SLEAP** is now installed in the :code:`sleap_env` conda environment. With the environment active, you can run the labeling GUI by entering the following command: :: @@ -72,10 +72,10 @@ you can run the labeling GUI by entering the following command: Linux ----- -No Linux conda packages are currently provided by the **sLEAP** channel. However, installing via :code:`pip` should not +No Linux conda packages are currently provided by the **SLEAP** channel. However, installing via :code:`pip` should not be difficult on most Linux systems. The first step is to get a working version of TensorFlow installed in your Python environment. Follow official directions for installing TensorFlow_ with GPU support. Once TensorFlow is installed, simple -issue the following command to install **sLEAP** +issue the following command to install **SLEAP** .. _TensorFlow: https://www.tensorflow.org/install/gpu @@ -83,7 +83,7 @@ issue the following command to install **sLEAP** pip install git+https://github.com/murthylab/sleap.git -**sLEAP** is now installed you can run the labeling GUI by entering the following command: +**SLEAP** is now installed you can run the labeling GUI by entering the following command: :: @@ -93,7 +93,7 @@ Mac OS ------ The installation for Mac OS X is the same as for Linux, although there's no TensorFlow GPU support for Mac OS. -You can install TensorFlow and **sLEAP** together by running +You can install TensorFlow and **SLEAP** together by running :: @@ -102,6 +102,6 @@ You can install TensorFlow and **sLEAP** together by running Research -------- -If you use **sLEAP** in your research please acknowledge ... +If you use **SLEAP** in your research please acknowledge ... diff --git a/appveyor.yml b/appveyor.yml index ecbca628b..e466bc76c 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -3,11 +3,13 @@ version: '{build}' clone_depth: 5 environment: + COVERALLS_REPO_TOKEN: + secure: VsCyKmdi8x0OFK+Jbzk7ZRAW3EtojWP85TWqWKi+vuGmdiQFX7rLPnuaw3kt++a8 access_token: secure: T7XuBtHDu85Tk/d1AeyfhW3CVyzaoddTWmR4xsPIdQ3di0R6x8ncWqw3KrYXkWJm BUILD_DIR: "build" - + conda_access_token: secure: d+v++uejbVEhIuaJSuFIOA== matrix: @@ -56,14 +58,20 @@ install: # We need to install this separately, what a mess. # - pip install PySide2 opencv-python imgaug cattrs -# Install dev requirements too. + # Install dev requirements too. - pip install -r dev_requirements.txt + # Install sleap package + - pip install . + build: off test_script: - cmd: activate sleap_appveyor - cmd: where python - - cmd: python -m pytest tests/ + - cmd: pytest --cov=sleap tests/ + +on_success: + - cmd: coveralls # here we are going to override common configuration for: diff --git a/dev_requirements.txt b/dev_requirements.txt index f0e899b95..7b70af633 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -3,4 +3,5 @@ pytest-qt pytest-cov ipython sphinx -sphinx_rtd_theme \ No newline at end of file +sphinx_rtd_theme +coveralls \ No newline at end of file diff --git a/docs/Makefile b/docs/Makefile index 8f5aa6ab8..9f8237083 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -5,7 +5,7 @@ SPHINXOPTS = SPHINXBUILD = sphinx-build SOURCEDIR = . -BUILDDIR = ..\..\sleap-docs +BUILDDIR = ../../sleap-docs # Export the BUILDDIR so we can pick it up in conf.py. We need this to # be able to copy some the files in _static to an alternative location diff --git a/docs/conf.py b/docs/conf.py index fb170194e..78f497c50 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -19,7 +19,7 @@ # -- Project information ----------------------------------------------------- -project = 'LEAP' +project = 'SLEAP' copyright = '2019, Murthy Lab @ Princeton' author = 'Talmo D. Pereira, Nat Tabris, David M. Turner' @@ -105,7 +105,7 @@ # -- Options for HTMLHelp output --------------------------------------------- # Output file base name for HTML help builder. -htmlhelp_basename = 'sLEAPdoc' +htmlhelp_basename = 'SLEAPdoc' # -- Options for LaTeX output ------------------------------------------------ @@ -132,7 +132,7 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'sLEAP.tex', 'sLEAP Documentation', + (master_doc, 'SLEAP.tex', 'SLEAP Documentation', 'Talmo D. Pereira, Nat Tabris, David M. Turner', 'manual'), ] @@ -142,7 +142,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ - (master_doc, 'sleap', 'sLEAP Documentation', + (master_doc, 'Sleap', 'SLEAP Documentation', [author], 1) ] @@ -153,8 +153,8 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'sLEAP', 'sLEAP Documentation', - author, 'sLEAP', 'One line description of project.', + (master_doc, 'SLEAP', 'SLEAP Documentation', + author, 'SLEAP', 'One line description of project.', 'Miscellaneous'), ] diff --git a/docs/gui.rst b/docs/gui.rst index a1d657797..76ee2b4ff 100644 --- a/docs/gui.rst +++ b/docs/gui.rst @@ -3,18 +3,100 @@ GUI .. automodule:: sleap.gui.app :members: + +Video Player +------------- .. automodule:: sleap.gui.video :members: + +Dialogs +------------- + +Active Learning +^^^^^^^^^^^^^^^ +.. automodule:: sleap.gui.active + :members: + +Video Importer +^^^^^^^^^^^^^^ .. automodule:: sleap.gui.importvideos :members: -.. automodule:: sleap.gui.confmapsplot + +Merging +^^^^^^^^^^^^^^ +.. automodule:: sleap.gui.merge :members: -.. automodule:: sleap.gui.quiverplot + +Shortcuts +^^^^^^^^^^^^^^ +.. automodule:: sleap.gui.shortcuts :members: -.. automodule:: sleap.gui.dataviews + +Suggestions +^^^^^^^^^^^^^^ +.. automodule:: sleap.gui.suggestions :members: -.. automodule:: sleap.gui.multicheck + +Training Profiles +^^^^^^^^^^^^^^^^^ +.. automodule:: sleap.gui.training_editor + :members: + +Other Widgets +------------- + +Form builder +^^^^^^^^^^^^^^ +.. automodule:: sleap.gui.formbuilder :members: + +Slider +^^^^^^^^^^^^^^ .. automodule:: sleap.gui.slider :members: +Multicheck +^^^^^^^^^^^^^^ +.. automodule:: sleap.gui.multicheck + :members: + +Overlays +------------- + +Instances +^^^^^^^^^^^^^^ +.. automodule:: sleap.gui.overlays.instance + :members: + +Tracks +^^^^^^^^^^^^^^ +.. automodule:: sleap.gui.overlays.tracks + :members: + +Anchors +^^^^^^^^^^^^^^ +.. automodule:: sleap.gui.overlays.anchors + :members: + +Datasource classes +^^^^^^^^^^^^^^^^^^ +.. automodule:: sleap.gui.overlays.base + :members: + +Confidence maps +^^^^^^^^^^^^^^^ +.. automodule:: sleap.gui.overlays.confmaps + :members: + + +Part affinity fields +^^^^^^^^^^^^^^^^^^^^ +.. automodule:: sleap.gui.overlays.pafs + :members: + + + +Dataviews +------------- +.. automodule:: sleap.gui.dataviews + :members: diff --git a/docs/index.rst b/docs/index.rst index bd16e505b..8fe194a91 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -3,7 +3,7 @@ .. _sleap: .. toctree:: - :caption: sLEAP Package + :caption: SLEAP Package :maxdepth: 3 tutorial @@ -14,6 +14,7 @@ training inference gui + misc .. _Indices_and_Tables: diff --git a/docs/misc.rst b/docs/misc.rst new file mode 100644 index 000000000..cea0d774f --- /dev/null +++ b/docs/misc.rst @@ -0,0 +1,35 @@ +Misc +======== + +Utils +------------- +.. automodule:: sleap.util + :members: + +Range list +------------- +.. automodule:: sleap.rangelist + :members: + +Legacy formats +-------------- +.. automodule:: sleap.io.legacy + :members: + +Info tools +---------- + +Metrics +^^^^^^^^^^^^^^ +.. automodule:: sleap.info.metrics + :members: + +Summary +^^^^^^^^^^^^^^ +.. automodule:: sleap.info.summary + :members: + +Track Analysis +^^^^^^^^^^^^^^ +.. automodule:: sleap.info.write_tracking_h5 + :members: diff --git a/docs/tutorial.rst b/docs/tutorial.rst index ad0fca3ca..3932f9e13 100644 --- a/docs/tutorial.rst +++ b/docs/tutorial.rst @@ -1,11 +1,11 @@ Tutorial ======== -Before you can use sLEAP, you’ll need to install it. Follow the -instructions at :ref:`Installation` to install sLEAP and +Before you can use SLEAP, you’ll need to install it. Follow the +instructions at :ref:`Installation` to install SLEAP and start the GUI app. -There are three main stages of using sLEAP: +There are three main stages of using SLEAP: 1. Creating a project, opening a movie and defining the skeleton; @@ -18,7 +18,7 @@ There are three main stages of using sLEAP: Stage 1: Creating a project --------------------------- -When you first start sLEAP you’ll see an open dialog. Since you don’t +When you first start SLEAP you’ll see an open dialog. Since you don’t yet have a project to open, click “Cancel” and you’ll be left with a new, empty project. @@ -32,7 +32,7 @@ on the right side of the main window, the “Add Video” command in the |image0| You’ll then be able to select one or more video files and click “Open”. -sLEAP currently supports mp4, avi, and h5 files. For mp4 and avi files, +SLEAP currently supports mp4, avi, and h5 files. For mp4 and avi files, you’ll be asked whether to import the video as grayscale. For h5 files, you’ll be asked the dataset and whether the video is stored with channels first or last. @@ -65,7 +65,7 @@ Stage 2: Labeling and learning We start by assembling a candidate group of images to label. You can either pick your own frames or let the system suggest a set of frames -using the “Generate Suggestions” panel. sLEAP can choose these frames +using the “Generate Suggestions” panel. SLEAP can choose these frames (i) randomly, or using (ii) Strides (evenly spaced samples), (iii) PCA (runs Principle Component Analysis on the images, clusters the images into groups, and uses sample frames from each cluster), or (iv) BRISK @@ -102,9 +102,9 @@ Saving ~~~~~~ Since this is a new project, you’ll need to select a location and name -the first time you save. sLEAP will ask you to save before closing any +the first time you save. SLEAP will ask you to save before closing any project that has been changed to avoid losing any work. Note: There is -not yet an “undo” feature built into sLEAP. If you want to make +not yet an “undo” feature built into SLEAP. If you want to make temporary changes to a project, use the “Save As…” command first to save a copy of your project. @@ -197,7 +197,7 @@ model doesn’t improve for a certain number of epochs (15 by default) First we train a model for confidence maps, part affinity fields, and centroids, and then we run inference. The GUI doesn’t yet give you a way to monitor the progress during inference, although you can get more -information in the console window from which you started sLEAP. +information in the console window from which you started SLEAP. When active learning finishes, you’ll be told how many instances were predicted. Suggested frames with predicted instances will be marked in @@ -265,7 +265,7 @@ Running inference remotely (optional) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ It’s also possible to run inference using the command line interface, which is -useful if you’re going to run on a cluster). The command to run inference on +useful if you’re going to run on a cluster. The command to run inference on an entire video is: :: @@ -276,7 +276,7 @@ an entire video is: -m path/to/models/your_paf_model.json \ -m path/to/models/your_centroid_model.json -The predictions will be saved in path/to/video.mp4.predictions.json.zip, +The predictions will be saved in path/to/video.mp4.predictions.h5, which you can open from the GUI app. You can also import these predictions into your project by opening your project and then using the "Import Predictions..." command in the "Predict" menu. diff --git a/environment.yml b/environment.yml index 7d929639a..779e21182 100644 --- a/environment.yml +++ b/environment.yml @@ -17,7 +17,7 @@ dependencies: - python-rapidjson - pip - pip: - - opencv-python==3.4.2.17 + - opencv-python-headless==3.4.1.15 - PySide2==5.12.0 - imgaug - cattrs==1.0.0rc0 diff --git a/requirements.txt b/requirements.txt index dd0c26892..21763ec20 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ tensorflow keras h5py python-rapidjson -opencv-python==3.4.2.17 +opencv-python-headless==3.4.1.15 pandas psutil PySide2 diff --git a/sleap/config/active.yaml b/sleap/config/active.yaml index 1c061da42..fa23b0b48 100644 --- a/sleap/config/active.yaml +++ b/sleap/config/active.yaml @@ -26,6 +26,11 @@ expert: type: bool default: False +- name: _dont_use_pafs + label: Single-instance mode (without pafs) + type: bool + default: False + - name: _view_paf label: View Edge Profile... type: button @@ -66,6 +71,7 @@ expert: label: Negative samples (if cropping) type: int default: 0 + range: 0,1000 - name: batch_size label: Batch Size @@ -110,6 +116,7 @@ learning: label: Negative samples type: int default: 20 + range: 0,1000 - name: batch_size label: Batch Size diff --git a/sleap/config/shortcuts.yaml b/sleap/config/shortcuts.yaml index cda364e6c..4e64205ed 100644 --- a/sleap/config/shortcuts.yaml +++ b/sleap/config/shortcuts.yaml @@ -1,32 +1,32 @@ -"new": QKeySequence.New -"open": QKeySequence.Open -"save": QKeySequence.Save -"save as": QKeySequence.SaveAs -"close": QKeySequence.Close -"add videos": Qt.CTRL + Qt.Key_A -"next video": QKeySequence.Forward -"prev video": QKeySequence.Back -"goto frame": Qt.CTRL + Qt.Key_J -"mark frame": Qt.CTRL + Qt.Key_M -"goto marked": Qt.CTRL + Qt.SHIFT + Qt.Key_M -"add instance": Qt.CTRL + Qt.Key_I -"delete instance": Qt.CTRL + Qt.Key_Backspace -"delete track": Qt.CTRL + Qt.SHIFT + Qt.Key_Backspace -"transpose": Qt.CTRL + Qt.Key_T -"select next": QKeySequence(Qt.Key.Key_QuoteLeft) -"clear selection": QKeySequence(Qt.Key.Key_Escape) -"goto next": Qt.CTRL + Qt.Key_Period -"goto prev": -"goto next user": Qt.CTRL + Qt.Key_Greater -"goto next suggestion": QKeySequence.FindNext -"goto prev suggestion": QKeySequence.FindPrevious -"goto next track": Qt.CTRL + Qt.Key_E -"show labels": Qt.CTRL + Qt.Key_Tab -"show edges": Qt.CTRL + Qt.SHIFT + Qt.Key_Tab -"show trails": -"color predicted": -"fit": Qt.CTRL + Qt.Key_Equal -"learning": Qt.CTRL + Qt.Key_L -"export clip": -"delete clip": -"delete area": Qt.CTRL + Qt.Key_K +add instance: Ctrl+I +add videos: Ctrl+A +clear selection: Esc +close: QKeySequence.Close +color predicted: +delete area: Ctrl+K +delete clip: +delete instance: Ctrl+Backspace +delete track: Ctrl+Shift+Backspace +export clip: +fit: Ctrl+= +goto frame: Ctrl+J +goto marked: Ctrl+Shift+M +goto next suggestion: QKeySequence.FindNext +goto next track spawn: Ctrl+E +goto next user: Ctrl+> +goto next labeled: Ctrl+. +goto prev suggestion: QKeySequence.FindPrevious +goto prev labeled: +learning: Ctrl+L +mark frame: Ctrl+M +new: Ctrl+N +next video: QKeySequence.Forward +open: Ctrl+O +prev video: QKeySequence.Back +save as: QKeySequence.SaveAs +save: Ctrl+S +select next: '`' +show edges: Ctrl+Shift+Tab +show labels: Ctrl+Tab +show trails: +transpose: Ctrl+T diff --git a/sleap/config/training_editor.yaml b/sleap/config/training_editor.yaml index 5ac01e6ad..c11f18a17 100644 --- a/sleap/config/training_editor.yaml +++ b/sleap/config/training_editor.yaml @@ -6,6 +6,13 @@ model: options: confmaps,pafs,centroids default: confmaps +- name: arch # backbone_name + label: Architecture + type: list + default: + options: LeapCNN,UNet,StackedHourglass + # options: LeapCNN,UNet,StackedHourglass,StackedUNet + - name: down_blocks label: Down Blocks type: int @@ -35,20 +42,35 @@ model: label: Upsampling Layers type: bool default: False - + - name: interp label: Interpolation type: list default: bilinear options: bilinear -# skeletons? +# stacked model: +- name: num_stacks + label: Stacks + type: int + default: 3 -- name: arch # backbone_name - label: Architecture - type: list - default: - options: LeapCNN,StackedHourglass,UNet,StackedUNet +# - name: batch_norm +# label: Batch norm +# type: bool +# default: True + +# - name: intermediate_inputs +# label: Intermediate inputs +# type: bool +# default: True + +# - name: initial_stride +# label: Initial stride +# type: int +# default: 1 + +# skeletons? datagen: diff --git a/sleap/gui/active.py b/sleap/gui/active.py index d47611f52..e86ce1bc2 100644 --- a/sleap/gui/active.py +++ b/sleap/gui/active.py @@ -1,11 +1,13 @@ +""" +Module for running active learning (or just inference) from GUI. +""" + import os import cattr -from datetime import datetime -import multiprocessing from functools import reduce from pkg_resources import Requirement, resource_filename -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple from sleap.io.dataset import Labels from sleap.io.video import Video @@ -16,46 +18,71 @@ from PySide2 import QtWidgets, QtCore + class ActiveLearningDialog(QtWidgets.QDialog): + """Active learning dialog. + + The dialog can be used in different modes: + * simplified active learning (fewer controls) + * expert active learning (full controls) + * inference only + + Arguments: + labels_filename: Path to the dataset where we'll get training data. + labels: The dataset where we'll get training data and add predictions. + mode: String which specified mode ("active", "expert", or "inference"). + """ learningFinished = QtCore.Signal() - def __init__(self, - labels_filename: str, labels: Labels, - mode: str="expert", - only_predict: bool=False, - *args, **kwargs): + def __init__( + self, + labels_filename: str, + labels: Labels, + mode: str = "expert", + *args, + **kwargs, + ): super(ActiveLearningDialog, self).__init__(*args, **kwargs) self.labels_filename = labels_filename self.labels = labels self.mode = mode - self.only_predict = only_predict print(f"Number of frames to train on: {len(labels.user_labeled_frames)}") - title = dict(learning="Active Learning", - inference="Inference", - expert="Inference Pipeline", - ) + title = dict( + learning="Active Learning", + inference="Inference", + expert="Inference Pipeline", + ) - learning_yaml = resource_filename(Requirement.parse("sleap"),"sleap/config/active.yaml") + learning_yaml = resource_filename( + Requirement.parse("sleap"), "sleap/config/active.yaml" + ) self.form_widget = YamlFormWidget( - yaml_file=learning_yaml, - which_form=self.mode, - title=title[self.mode] + " Settings") + yaml_file=learning_yaml, + which_form=self.mode, + title=title[self.mode] + " Settings", + ) # form ui self.training_profile_widgets = dict() if "conf_job" in self.form_widget.fields: - self.training_profile_widgets[ModelOutputType.CONFIDENCE_MAP] = self.form_widget.fields["conf_job"] + self.training_profile_widgets[ + ModelOutputType.CONFIDENCE_MAP + ] = self.form_widget.fields["conf_job"] if "paf_job" in self.form_widget.fields: - self.training_profile_widgets[ModelOutputType.PART_AFFINITY_FIELD] = self.form_widget.fields["paf_job"] + self.training_profile_widgets[ + ModelOutputType.PART_AFFINITY_FIELD + ] = self.form_widget.fields["paf_job"] if "centroid_job" in self.form_widget.fields: - self.training_profile_widgets[ModelOutputType.CENTROIDS] = self.form_widget.fields["centroid_job"] + self.training_profile_widgets[ + ModelOutputType.CENTROIDS + ] = self.form_widget.fields["centroid_job"] self._rebuild_job_options() self._update_job_menus(init=True) @@ -63,8 +90,8 @@ def __init__(self, buttons = QtWidgets.QDialogButtonBox() self.cancel_button = buttons.addButton(QtWidgets.QDialogButtonBox.Cancel) self.run_button = buttons.addButton( - "Run "+title[self.mode], - QtWidgets.QDialogButtonBox.AcceptRole) + "Run " + title[self.mode], QtWidgets.QDialogButtonBox.AcceptRole + ) self.status_message = QtWidgets.QLabel("hi!") @@ -84,21 +111,29 @@ def __init__(self, # connect actions to buttons def edit_conf_profile(): - self.view_profile(self.form_widget["conf_job"], - model_type=ModelOutputType.CONFIDENCE_MAP) + self._view_profile( + self.form_widget["conf_job"], model_type=ModelOutputType.CONFIDENCE_MAP + ) + def edit_paf_profile(): - self.view_profile(self.form_widget["paf_job"], - model_type=ModelOutputType.PART_AFFINITY_FIELD) + self._view_profile( + self.form_widget["paf_job"], + model_type=ModelOutputType.PART_AFFINITY_FIELD, + ) + def edit_cent_profile(): - self.view_profile(self.form_widget["centroid_job"], - model_type=ModelOutputType.CENTROIDS) + self._view_profile( + self.form_widget["centroid_job"], model_type=ModelOutputType.CENTROIDS + ) if "_view_conf" in self.form_widget.buttons: self.form_widget.buttons["_view_conf"].clicked.connect(edit_conf_profile) if "_view_paf" in self.form_widget.buttons: self.form_widget.buttons["_view_paf"].clicked.connect(edit_paf_profile) if "_view_centoids" in self.form_widget.buttons: - self.form_widget.buttons["_view_centoids"].clicked.connect(edit_cent_profile) + self.form_widget.buttons["_view_centoids"].clicked.connect( + edit_cent_profile + ) if "_view_datagen" in self.form_widget.buttons: self.form_widget.buttons["_view_datagen"].clicked.connect(self.view_datagen) @@ -110,8 +145,13 @@ def edit_cent_profile(): self.update_gui() def _rebuild_job_options(self): + """ + Rebuilds list of profile options (checking for new profile files). + """ # load list of job profiles from directory - profile_dir = resource_filename(Requirement.parse("sleap"), "sleap/training_profiles") + profile_dir = resource_filename( + Requirement.parse("sleap"), "sleap/training_profiles" + ) labels_dir = os.path.join(os.path.dirname(self.labels_filename), "models") self.job_options = dict() @@ -122,33 +162,48 @@ def _rebuild_job_options(self): # list default profiles find_saved_jobs(profile_dir, self.job_options) - def _update_job_menus(self, init=False): + def _update_job_menus(self, init: bool = False): + """Updates the menus with training profile options. + + Args: + init: Whether this is first time calling (so we should connect + signals), or we're just updating menus. + + Returns: + None. + """ for model_type, field in self.training_profile_widgets.items(): if model_type not in self.job_options: self.job_options[model_type] = [] if init: - field.currentIndexChanged.connect(lambda idx, mt=model_type: self.select_job(mt, idx)) + field.currentIndexChanged.connect( + lambda idx, mt=model_type: self._update_from_selected_job(mt, idx) + ) else: # block signals so we can update combobox without overwriting # any user data with the defaults from the profile field.blockSignals(True) - field.set_options(self.option_list_from_jobs(model_type)) + field.set_options(self._option_list_from_jobs(model_type)) # enable signals again so that choice of profile will update params field.blockSignals(False) @property - def frame_selection(self): + def frame_selection(self) -> Dict[Video, List[int]]: + """ + Returns dictionary with frames that user has selected for inference. + """ return self._frame_selection @frame_selection.setter - def frame_selection(self, frame_selection): + def frame_selection(self, frame_selection: Dict[str, Dict[Video, List[int]]]): + """Sets options of frames on which to run inference.""" self._frame_selection = frame_selection if "_predict_frames" in self.form_widget.fields.keys(): prediction_options = [] def count_total_frames(videos_frames): - return reduce(lambda x,y:x+y, map(len, videos_frames.values())) + return reduce(lambda x, y: x + y, map(len, videos_frames.values())) # Determine which options are available given _frame_selection @@ -177,9 +232,12 @@ def count_total_frames(videos_frames): prediction_options.append(f"entire video ({video_length} frames)") - self.form_widget.fields["_predict_frames"].set_options(prediction_options, default_option) + self.form_widget.fields["_predict_frames"].set_options( + prediction_options, default_option + ) def show(self): + """Shows dialog (we hide rather than close to maintain settings).""" super(ActiveLearningDialog, self).show() # TODO: keep selection and any items added from training editor @@ -188,6 +246,7 @@ def show(self): self._update_job_menus() def update_gui(self): + """Updates gui state after user changes to options.""" form_data = self.form_widget.get_form_data() can_run = True @@ -211,38 +270,64 @@ def update_gui(self): self.form_widget.fields["instance_crop"].setEnabled(True) error_messages = [] - if form_data.get("_use_trained_confmaps", False) and \ - form_data.get("_use_trained_pafs", False): + if form_data.get("_use_trained_confmaps", False) and form_data.get( + "_use_trained_pafs", False + ): # make sure trained models are compatible conf_job, _ = self._get_current_job(ModelOutputType.CONFIDENCE_MAP) paf_job, _ = self._get_current_job(ModelOutputType.PART_AFFINITY_FIELD) - if conf_job.trainer.scale != paf_job.trainer.scale: - can_run = False - error_messages.append(f"training image scale for confmaps ({conf_job.trainer.scale}) does not match pafs ({paf_job.trainer.scale})") - if conf_job.trainer.instance_crop != paf_job.trainer.instance_crop: - can_run = False - crop_model_name = "confmaps" if conf_job.trainer.instance_crop else "pafs" - error_messages.append(f"exactly one model ({crop_model_name}) was trained on crops") - if use_centroids and not conf_job.trainer.instance_crop: - can_run = False - error_messages.append(f"models used with centroids must be trained on cropped images") + # only check compatible if we have both profiles + if conf_job is not None and paf_job is not None: + if conf_job.trainer.scale != paf_job.trainer.scale: + can_run = False + error_messages.append( + f"training image scale for confmaps ({conf_job.trainer.scale}) does not match pafs ({paf_job.trainer.scale})" + ) + if conf_job.trainer.instance_crop != paf_job.trainer.instance_crop: + can_run = False + crop_model_name = ( + "confmaps" if conf_job.trainer.instance_crop else "pafs" + ) + error_messages.append( + f"exactly one model ({crop_model_name}) was trained on crops" + ) + if use_centroids and not conf_job.trainer.instance_crop: + can_run = False + error_messages.append( + f"models used with centroids must be trained on cropped images" + ) message = "" if not can_run: - message = "Unable to run with selected models:\n- " + \ - ";\n- ".join(error_messages) + "." + message = ( + "Unable to run with selected models:\n- " + + ";\n- ".join(error_messages) + + "." + ) self.status_message.setText(message) self.run_button.setEnabled(can_run) - def _get_current_job(self, model_type): + def _get_current_job(self, model_type: ModelOutputType) -> Tuple[TrainingJob, str]: + """Returns training job currently selected for given model type. + + Args: + model_type: The type of model for which we want data. + + Returns: Tuple of (TrainingJob, path to job profile). + """ # by default use the first model for a given type idx = 0 if model_type in self.training_profile_widgets: field = self.training_profile_widgets[model_type] idx = field.currentIndex() + # Check that selection corresponds to something we're loaded + # (it won't when user is adding a new profile) + if idx >= len(self.job_options[model_type]): + return None, None + job_filename, job = self.job_options[model_type][idx] if model_type == ModelOutputType.CENTROIDS: @@ -253,11 +338,16 @@ def _get_current_job(self, model_type): return job, job_filename def _get_model_types_to_use(self): + """Returns lists of model types which user has enabled.""" form_data = self.form_widget.get_form_data() types_to_use = [] + # always include confidence maps types_to_use.append(ModelOutputType.CONFIDENCE_MAP) - types_to_use.append(ModelOutputType.PART_AFFINITY_FIELD) + + # by default we want to use part affinity fields + if not form_data.get("_dont_use_pafs", False): + types_to_use.append(ModelOutputType.PART_AFFINITY_FIELD) # by default we want to use centroids if form_data.get("_use_centroids", True): @@ -265,11 +355,12 @@ def _get_model_types_to_use(self): return types_to_use - def _get_current_training_jobs(self): + def _get_current_training_jobs(self) -> Dict[ModelOutputType, TrainingJob]: + """Returns all currently selected training jobs.""" form_data = self.form_widget.get_form_data() training_jobs = dict() - default_use_trained = (self.mode == "inference") + default_use_trained = self.mode == "inference" for model_type in self._get_model_types_to_use(): job, _ = self._get_current_job(model_type) @@ -293,6 +384,7 @@ def _get_current_training_jobs(self): return training_jobs def run(self): + """Run active learning (or inference) with current dialog settings.""" # Collect TrainingJobs and params from form form_data = self.form_widget.get_form_data() training_jobs = self._get_current_training_jobs() @@ -316,43 +408,36 @@ def run(self): with_tracking = True else: frames_to_predict = dict() - save_confmaps_pafs = form_data.get("_save_confmaps_pafs", False) + save_confmaps_pafs = False + # Disable save_confmaps_pafs since not currently working. + # The problem is that we can't put data for different crop sizes + # all into a single h5 datasource. It's now possible to view live + # predicted confmap and paf in the gui, so this isn't high priority. + # If you want to enable, uncomment this: + # save_confmaps_pafs = form_data.get("_save_confmaps_pafs", False) # Run active learning pipeline using the TrainingJobs - new_lfs = run_active_learning_pipeline( - labels_filename = self.labels_filename, - labels = self.labels, - training_jobs = training_jobs, - frames_to_predict = frames_to_predict, - with_tracking = with_tracking, - save_confmaps_pafs = save_confmaps_pafs) - - # remove labeledframes without any predicted instances - new_lfs = list(filter(lambda lf: len(lf.instances), new_lfs)) - # Update labels with results of active learning - - new_tracks = {inst.track for lf in new_lfs for inst in lf.instances if inst.track is not None} - if len(new_tracks) < 50: - self.labels.tracks = list(set(self.labels.tracks).union(new_tracks)) - # if there are more than 50 predicted tracks, assume this is wrong (FIXME?) - elif len(new_tracks): - for lf in new_lfs: - for inst in lf.instances: - inst.track = None - - # Update Labels with new data - # add new labeled frames - self.labels.extend_from(new_lfs) - # combine instances from labeledframes with same video/frame_idx - self.labels.merge_matching_frames() + new_counts = run_active_learning_pipeline( + labels_filename=self.labels_filename, + labels=self.labels, + training_jobs=training_jobs, + frames_to_predict=frames_to_predict, + with_tracking=with_tracking, + ) self.learningFinished.emit() - QtWidgets.QMessageBox(text=f"Active learning has finished. Instances were predicted on {len(new_lfs)} frames.").exec_() + QtWidgets.QMessageBox( + text=f"Active learning has finished. Instances were predicted on {new_counts} frames." + ).exec_() def view_datagen(self): - from sleap.nn.datagen import generate_training_data, \ - generate_confmaps_from_points, generate_pafs_from_points + """Shows windows with sample visual data that will be used training.""" + from sleap.nn.datagen import ( + generate_training_data, + generate_confmaps_from_points, + generate_pafs_from_points, + ) from sleap.io.video import Video from sleap.gui.overlays.confmaps import demo_confmaps from sleap.gui.overlays.pafs import demo_pafs @@ -370,19 +455,23 @@ def view_datagen(self): negative_samples = form_data.get("negative_samples", 0) imgs, points = generate_training_data( - self.labels, - params = dict( - frame_limit = 10, - scale = scale, - instance_crop = instance_crop, - min_crop_size = min_crop_size, - negative_samples = negative_samples)) + self.labels, + params=dict( + frame_limit=10, + scale=scale, + instance_crop=instance_crop, + min_crop_size=min_crop_size, + negative_samples=negative_samples, + ), + ) skeleton = self.labels.skeletons[0] img_shape = (imgs.shape[1], imgs.shape[2]) vid = Video.from_numpy(imgs * 255) - confmaps = generate_confmaps_from_points(points, skeleton, img_shape, sigma=sigma_confmaps) + confmaps = generate_confmaps_from_points( + points, skeleton, img_shape, sigma=sigma_confmaps + ) conf_win = demo_confmaps(confmaps, vid) conf_win.activateWindow() conf_win.move(200, 200) @@ -390,14 +479,14 @@ def view_datagen(self): pafs = generate_pafs_from_points(points, skeleton, img_shape, sigma=sigma_pafs) paf_win = demo_pafs(pafs, vid) paf_win.activateWindow() - paf_win.move(220+conf_win.rect().width(), 200) + paf_win.move(220 + conf_win.rect().width(), 200) # FIXME: hide dialog so use can see other windows # can we show these windows without closing dialog? self.hide() - # open profile editor in new dialog window - def view_profile(self, filename, model_type, windows=[]): + def _view_profile(self, filename: str, model_type: ModelOutputType, windows=[]): + """Opens profile editor in new dialog window.""" saved_files = [] win = TrainingEditor(filename, saved_files=saved_files, parent=self) windows.append(win) @@ -406,33 +495,40 @@ def view_profile(self, filename, model_type, windows=[]): for new_filename in saved_files: self._add_job_file_to_list(new_filename, model_type) - def option_list_from_jobs(self, model_type): + def _option_list_from_jobs(self, model_type: ModelOutputType): + """Returns list of menu options for given model type.""" jobs = self.job_options[model_type] option_list = [name for (name, job) in jobs] option_list.append("---") option_list.append("Select a training profile file...") return option_list - def add_job_file(self, model_type): + def _add_job_file(self, model_type): + """Allow user to add training profile for given model type.""" filename, _ = QtWidgets.QFileDialog.getOpenFileName( - None, dir=None, - caption="Select training profile...", - filter="TrainingJob JSON (*.json)") + None, + dir=None, + caption="Select training profile...", + filter="TrainingJob JSON (*.json)", + ) self._add_job_file_to_list(filename, model_type) field = self.training_profile_widgets[model_type] # if we didn't successfully select a new file, then clear selection - if field.currentIndex() == field.count()-1: # subtract 1 for separator + if field.currentIndex() == field.count() - 1: # subtract 1 for separator field.setCurrentIndex(-1) - def _add_job_file_to_list(self, filename, model_type): + def _add_job_file_to_list(self, filename: str, model_type: ModelOutputType): + """Adds selected training profile for given model type.""" if len(filename): try: # try to load json as TrainingJob job = TrainingJob.load_json(filename) except: # but do raise any other type of error - QtWidgets.QMessageBox(text=f"Unable to load a training profile from {filename}.").exec_() + QtWidgets.QMessageBox( + text=f"Unable to load a training profile from {filename}." + ).exec_() raise else: # we loaded the json as a TrainingJob, so see what type of model it's for @@ -444,18 +540,26 @@ def _add_job_file_to_list(self, filename, model_type): # update ui list if model_type in self.training_profile_widgets: field = self.training_profile_widgets[model_type] - field.set_options(self.option_list_from_jobs(model_type), filename) + field.set_options( + self._option_list_from_jobs(model_type), filename + ) else: - QtWidgets.QMessageBox(text=f"Profile selected is for training {str(file_model_type)} instead of {str(model_type)}.").exec_() + QtWidgets.QMessageBox( + text=f"Profile selected is for training {str(file_model_type)} instead of {str(model_type)}." + ).exec_() - def select_job(self, model_type, idx): + def _update_from_selected_job(self, model_type: ModelOutputType, idx: int): + """Updates dialog settings after user selects a training profile.""" jobs = self.job_options[model_type] - if idx == -1: return + if idx == -1: + return if idx < len(jobs): name, job = jobs[idx] training_params = cattr.unstructure(job.trainer) - training_params_specific = {f"{key}_{str(model_type)}":val for key,val in training_params.items()} + training_params_specific = { + f"{key}_{str(model_type)}": val for key, val in training_params.items() + } # confmap and paf models should share some params shown in dialog (e.g. scale) # but centroids does not, so just set any centroid_foo fields from its profile if model_type in [ModelOutputType.CENTROIDS]: @@ -476,57 +580,53 @@ def select_job(self, model_type, idx): self.form_widget[field_name] = has_trained else: # last item is "select file..." - self.add_job_file(model_type) + self._add_job_file(model_type) -def make_default_training_jobs(): - from sleap.nn.model import Model, ModelOutputType - from sleap.nn.training import Trainer, TrainingJob +def make_default_training_jobs() -> Dict[ModelOutputType, TrainingJob]: + """Creates TrainingJobs with some default settings.""" + from sleap.nn.model import Model + from sleap.nn.training import Trainer from sleap.nn.architectures import unet, leap # Build Models (wrapper for Keras model with some metadata) models = dict() models[ModelOutputType.CONFIDENCE_MAP] = Model( - output_type=ModelOutputType.CONFIDENCE_MAP, - backbone=unet.UNet(num_filters=32)) + output_type=ModelOutputType.CONFIDENCE_MAP, backbone=unet.UNet(num_filters=32) + ) models[ModelOutputType.PART_AFFINITY_FIELD] = Model( - output_type=ModelOutputType.PART_AFFINITY_FIELD, - backbone=leap.LeapCNN(num_filters=64)) + output_type=ModelOutputType.PART_AFFINITY_FIELD, + backbone=leap.LeapCNN(num_filters=64), + ) # Build Trainers defaults = dict() defaults["shared"] = dict( - instance_crop = True, - val_size = 0.1, - augment_rotation=180, - batch_size=4, - learning_rate = 1e-4, - reduce_lr_factor=0.5, - reduce_lr_cooldown=3, - reduce_lr_min_delta=1e-6, - reduce_lr_min_lr = 1e-10, - amsgrad = True, - shuffle_every_epoch=True, - save_every_epoch = False, -# val_batches_per_epoch = 10, -# upsampling_layers = True, -# depth = 3, + instance_crop=True, + val_size=0.1, + augment_rotation=180, + batch_size=4, + learning_rate=1e-4, + reduce_lr_factor=0.5, + reduce_lr_cooldown=3, + reduce_lr_min_delta=1e-6, + reduce_lr_min_lr=1e-10, + amsgrad=True, + shuffle_every_epoch=True, + save_every_epoch=False, + # val_batches_per_epoch = 10, + # upsampling_layers = True, + # depth = 3, ) defaults[ModelOutputType.CONFIDENCE_MAP] = dict( - **defaults["shared"], - num_epochs=100, - steps_per_epoch=200, - reduce_lr_patience=5, - ) + **defaults["shared"], num_epochs=100, steps_per_epoch=200, reduce_lr_patience=5 + ) defaults[ModelOutputType.PART_AFFINITY_FIELD] = dict( - **defaults["shared"], - num_epochs=75, - steps_per_epoch = 100, - reduce_lr_patience=8, - ) + **defaults["shared"], num_epochs=75, steps_per_epoch=100, reduce_lr_patience=8 + ) trainers = dict() for type in models.keys(): @@ -540,16 +640,19 @@ def make_default_training_jobs(): return training_jobs -def find_saved_jobs(job_dir, jobs=None): + +def find_saved_jobs( + job_dir: str, jobs=None +) -> Dict[ModelOutputType, List[Tuple[str, TrainingJob]]]: """Find all the TrainingJob json files in a given directory. Args: job_dir: the directory in which to look for json files - jobs (optional): append to jobs, rather than creating new dict + jobs: If given, then the found jobs will be added to this object, + rather than creating new dict. Returns: dict of {ModelOutputType: list of (filename, TrainingJob) tuples} """ - from sleap.nn.training import TrainingJob files = os.listdir(job_dir) @@ -578,23 +681,50 @@ def find_saved_jobs(job_dir, jobs=None): return jobs + +def add_frames_from_json(labels: Labels, new_labels_json: str) -> int: + """Merges new predictions (given as json string) into dataset. + + Args: + labels: The dataset to which we're adding the predictions. + new_labels_json: A JSON string which can be deserialized into `Labels`. + Returns: + Number of labeled frames with new predictions. + """ + # Deserialize the new frames, matching to the existing videos/skeletons if possible + new_lfs = Labels.from_json(new_labels_json, match_to=labels).labeled_frames + + # Remove any frames without instances + new_lfs = list(filter(lambda lf: len(lf.instances), new_lfs)) + + # Now add them to labels and merge labeled frames with same video/frame_idx + labels.extend_from(new_lfs) + labels.merge_matching_frames() + + return len(new_lfs) + + def run_active_learning_pipeline( - labels_filename: str, - labels: Labels=None, - training_jobs: Dict=None, - frames_to_predict: Dict=None, - with_tracking: bool=False, - save_confmaps_pafs: bool=False, - skip_learning: bool=False): - # Imports here so we don't load TensorFlow before necessary - from sleap.nn.monitor import LossViewer - from sleap.nn.training import TrainingJob - from sleap.nn.model import ModelOutputType - from sleap.nn.inference import Predictor + labels_filename: str, + labels: Labels, + training_jobs: Dict["ModelOutputType", "TrainingJob"] = None, + frames_to_predict: Dict[Video, List[int]] = None, + with_tracking: bool = False, +) -> int: + """Run training (as needed) and inference. - from PySide2 import QtWidgets + Args: + labels_filename: Path to already saved current labels object. + labels: The current labels object; results will be added to this. + training_jobs: The TrainingJobs with params/hyperparams for training. + frames_to_predict: Dict that gives list of frame indices for each video. + with_tracking: Whether to run tracking code after we predict instances. + This should be used only when predicting on continuous set of frames. + + Returns: + Number of new frames added to labels. - labels = labels or Labels.load_json(labels_filename) + """ # Prepare our TrainingJobs @@ -605,124 +735,218 @@ def run_active_learning_pipeline( # Set the parameters specific to this run for job in training_jobs.values(): job.labels_filename = labels_filename -# job.trainer.scale = scale - # Run the TrainingJobs + save_dir = os.path.join(os.path.dirname(labels_filename), "models") - save_dir = os.path.join(os.path.dirname(labels_filename), "models") + # Train the TrainingJobs + trained_jobs = run_active_training(labels, training_jobs, save_dir) - # open training monitor window - win = LossViewer() - win.resize(600, 400) - win.show() + # Check that all the models were trained + if None in trained_jobs.values(): + return 0 + + # Run the Predictor for suggested frames + new_labeled_frame_count = run_active_inference( + labels, trained_jobs, save_dir, frames_to_predict, with_tracking + ) + + return new_labeled_frame_count + + +def run_active_training( + labels: Labels, + training_jobs: Dict["ModelOutputType", "TrainingJob"], + save_dir: str, + gui: bool = True, +) -> Dict["ModelOutputType", "TrainingJob"]: + """ + Run training for each training job. + + Args: + labels: Labels object from which we'll get training data. + training_jobs: Dict of the jobs to train. + save_dir: Path to the directory where we'll save inference results. + gui: Whether to show gui windows and process gui events. + + Returns: + Dict of trained jobs corresponding with input training jobs. + """ + + trained_jobs = dict() + + if gui: + from sleap.nn.monitor import LossViewer + + # open training monitor window + win = LossViewer() + win.resize(600, 400) + win.show() for model_type, job in training_jobs.items(): if getattr(job, "use_trained_model", False): # set path to TrainingJob already trained from previous run json_name = f"{job.run_name}.json" - training_jobs[model_type] = os.path.join(job.save_dir, json_name) - print(f"Using already trained model: {training_jobs[model_type]}") + trained_jobs[model_type] = os.path.join(job.save_dir, json_name) + print(f"Using already trained model: {trained_jobs[model_type]}") else: - print("Resetting monitor window.") - win.reset(what=str(model_type)) - win.setWindowTitle(f"Training Model - {str(model_type)}") + if gui: + print("Resetting monitor window.") + win.reset(what=str(model_type)) + win.setWindowTitle(f"Training Model - {str(model_type)}") + print(f"Start training {str(model_type)}...") - if not skip_learning: - # run training - pool, result = job.trainer.train_async(model=job.model, labels=labels, - save_dir=save_dir) + # Start training in separate process + # This makes it easier to ensure that tensorflow released memory when done + pool, result = job.trainer.train_async( + model=job.model, labels=labels, save_dir=save_dir + ) - while not result.ready(): + # Wait for training results + while not result.ready(): + if gui: QtWidgets.QApplication.instance().processEvents() - # win.check_messages() - result.wait(.01) + result.wait(0.01) - if result.successful(): - # get the path to the resulting TrainingJob file - training_jobs[model_type] = result.get() - print(f"Finished training {str(model_type)}.") - else: - training_jobs[model_type] = None + if result.successful(): + # get the path to the resulting TrainingJob file + trained_jobs[model_type] = result.get() + print(f"Finished training {str(model_type)}.") + else: + if gui: win.close() - QtWidgets.QMessageBox(text=f"An error occured while training {str(model_type)}. Your command line terminal may have more information about the error.").exec_() - result.get() - - - if not skip_learning: - for model_type, job in training_jobs.items(): - # load job from json - training_jobs[model_type] = TrainingJob.load_json(training_jobs[model_type]) + QtWidgets.QMessageBox( + text=f"An error occured while training {str(model_type)}. Your command line terminal may have more information about the error." + ).exec_() + trained_jobs[model_type] = None + result.get() + + # Load the jobs we just trained + for model_type, job in trained_jobs.items(): + # Replace path to saved TrainingJob with the deseralized object + if trained_jobs[model_type] is not None: + trained_jobs[model_type] = TrainingJob.load_json(trained_jobs[model_type]) + + if gui: + # close training monitor window + win.close() + + return trained_jobs + + +def run_active_inference( + labels: Labels, + training_jobs: Dict["ModelOutputType", "TrainingJob"], + save_dir: str, + frames_to_predict: Dict[Video, List[int]], + with_tracking: bool, + gui: bool = True, +) -> int: + """Run inference on specified frames using models from training_jobs. - # close training monitor window - win.close() - - if not skip_learning: - timestamp = datetime.now().strftime("%y%m%d_%H%M%S") - inference_output_path = os.path.join(save_dir, f"{timestamp}.inference.h5") + Args: + labels: The current labels object; results will be added to this. + training_jobs: The TrainingJobs with trained models to use. + save_dir: Path to the directory where we'll save inference results. + frames_to_predict: Dict that gives list of frame indices for each video. + with_tracking: Whether to run tracking code after we predict instances. + This should be used only when predicting on continuous set of frames. + gui: Whether to show gui windows and process gui events. - # Create Predictor from the results of training - predictor = Predictor(sleap_models=training_jobs, - with_tracking=with_tracking, - output_path=inference_output_path, - save_confmaps_pafs=save_confmaps_pafs) + Returns: + Number of new frames added to labels. + """ + from sleap.nn.inference import Predictor - # Run the Predictor for suggested frames - # We want to predict for suggested frames that don't already have user instances + # from multiprocessing import Pool - new_labeled_frames = [] - user_labeled_frames = labels.user_labeled_frames + # total_new_lf_count = 0 + # timestamp = datetime.now().strftime("%y%m%d_%H%M%S") + # inference_output_path = os.path.join(save_dir, f"{timestamp}.inference.h5") - # show message while running inference - win = QtWidgets.QProgressDialog() - win.setLabelText(" Running inference on selected frames... ") - win.show() - QtWidgets.QApplication.instance().processEvents() + # Create Predictor from the results of training + # pool = Pool(processes=1) + predictor = Predictor( + training_jobs=training_jobs, + with_tracking=with_tracking, + # output_path=inference_output_path, + # pool=pool + ) - for video, frames in frames_to_predict.items(): + if gui: + # show message while running inference + progress = QtWidgets.QProgressDialog( + f"Running inference on {len(frames_to_predict)} videos...", + "Cancel", + 0, + len(frames_to_predict), + ) + # win.setLabelText(" Running inference on selected frames... ") + progress.show() + QtWidgets.QApplication.instance().processEvents() + + new_lfs = [] + for i, (video, frames) in enumerate(frames_to_predict.items()): + QtWidgets.QApplication.instance().processEvents() if len(frames): - if not skip_learning: - # run predictions for desired frames in this video - # video_lfs = predictor.predict(input_video=video, frames=frames, output_path=inference_output_path) - - pool, result = predictor.predict_async( - input_video=video, - frames=frames) - - while not result.ready(): - QtWidgets.QApplication.instance().processEvents() - result.wait(.01) - - if result.successful(): - new_labels_json = result.get() - new_labels = Labels.from_json(new_labels_json, match_to=labels) - - video_lfs = new_labels.labeled_frames - else: - QtWidgets.QMessageBox(text=f"An error occured during inference. Your command line terminal may have more information about the error.").exec_() - result.get() - else: - import time - time.sleep(1) - video_lfs = [] - - new_labeled_frames.extend(video_lfs) + # Run inference for desired frames in this video + # result = predictor.predict_async( + new_lfs_video = predictor.predict(input_video=video, frames=frames) + new_lfs.extend(new_lfs_video) + + if gui: + progress.setValue(i) + if progress.wasCanceled(): + return 0 + + # while not result.ready(): + # if gui: + # QtWidgets.QApplication.instance().processEvents() + # result.wait(.01) + + # if result.successful(): + # new_labels_json = result.get() + + # Add new frames to labels + # (we're doing this for each video as we go since there was a problem + # when we tried to add frames for all videos together.) + # new_lf_count = add_frames_from_json(labels, new_labels_json) + + # total_new_lf_count += new_lf_count + # else: + # if gui: + # QtWidgets.QApplication.instance().processEvents() + # QtWidgets.QMessageBox(text=f"An error occured during inference. Your command line terminal may have more information about the error.").exec_() + # result.get() + + # predictor.pool.close() + + # Remove any frames without instances + new_lfs = list(filter(lambda lf: len(lf.instances), new_lfs)) + + # Now add them to labels and merge labeled frames with same video/frame_idx + # labels.extend_from(new_lfs) + labels.extend_from(new_lfs, unify=True) + labels.merge_matching_frames() # close message window - win.close() + if gui: + progress.close() + + # return total_new_lf_count + return len(new_lfs) - return new_labeled_frames if __name__ == "__main__": import sys -# labels_filename = "/Volumes/fileset-mmurthy/nat/shruthi/labels-mac.json" + # labels_filename = "/Volumes/fileset-mmurthy/nat/shruthi/labels-mac.json" labels_filename = sys.argv[1] labels = Labels.load_json(labels_filename) app = QtWidgets.QApplication() - win = ActiveLearningDialog(labels=labels,labels_filename=labels_filename) + win = ActiveLearningDialog(labels=labels, labels_filename=labels_filename) win.show() app.exec_() diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 19e3e725d..97e8911fa 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -1,57 +1,99 @@ -from PySide2 import QtCore, QtWidgets -from PySide2.QtCore import Qt +""" +Main GUI application for labeling, active learning, and proofreading. +""" -from PySide2.QtGui import QKeyEvent, QKeySequence +from PySide2 import QtCore, QtWidgets +from PySide2.QtCore import Qt, QEvent from PySide2.QtWidgets import QApplication, QMainWindow, QWidget, QDockWidget -from PySide2.QtWidgets import QVBoxLayout, QHBoxLayout, QGroupBox, QFormLayout -from PySide2.QtWidgets import QLabel, QPushButton, QLineEdit, QSpinBox, QDoubleSpinBox, QComboBox, QCheckBox -from PySide2.QtWidgets import QTableWidget, QTableView, QTableWidgetItem -from PySide2.QtWidgets import QMenu, QAction +from PySide2.QtWidgets import QVBoxLayout, QHBoxLayout, QGroupBox +from PySide2.QtWidgets import QLabel, QPushButton, QComboBox from PySide2.QtWidgets import QFileDialog, QMessageBox -import copy +import re import operator import os import sys -import yaml from pkg_resources import Requirement, resource_filename from pathlib import PurePath +from typing import Callable, Dict, Iterator, Optional + import numpy as np -import pandas as pd -from sleap.skeleton import Skeleton, Node -from sleap.instance import Instance, PredictedInstance, Point, LabeledFrame, Track +from sleap.skeleton import Skeleton +from sleap.instance import Instance, PredictedInstance, Point, Track from sleap.io.video import Video from sleap.io.dataset import Labels +from sleap.info.summary import StatisticSeries from sleap.gui.video import QtVideoPlayer -from sleap.gui.dataviews import VideosTable, SkeletonNodesTable, SkeletonEdgesTable, \ - LabeledFrameTable, SkeletonNodeModel, SuggestionsTable +from sleap.gui.dataviews import ( + VideosTable, + SkeletonNodesTable, + SkeletonEdgesTable, + LabeledFrameTable, + SkeletonNodeModel, + SuggestionsTable, +) from sleap.gui.importvideos import ImportVideos from sleap.gui.formbuilder import YamlFormWidget +from sleap.gui.merge import MergeDialog +from sleap.gui.shortcuts import Shortcuts, ShortcutDialog from sleap.gui.suggestions import VideoFrameSuggestions -from sleap.gui.overlays.tracks import TrackColorManager, TrackTrailOverlay +from sleap.gui.overlays.tracks import ( + TrackColorManager, + TrackTrailOverlay, + TrackListOverlay, +) from sleap.gui.overlays.instance import InstanceOverlay +from sleap.gui.overlays.anchors import NegativeAnchorOverlay OPEN_IN_NEW = True + class MainWindow(QMainWindow): + """The SLEAP GUI application. + + Each project (`Labels` dataset) that you have loaded in the GUI will + have its own `MainWindow` object. + + Attributes: + labels: The :class:`Labels` dataset. If None, a new, empty project + (i.e., :class:`Labels` object) will be created. + skeleton: The active :class:`Skeleton` for the project in the gui + video: The active :class:`Video` in view in the gui + """ + labels: Labels skeleton: Skeleton video: Video - def __init__(self, data_path=None, video=None, import_data=None, *args, **kwargs): + def __init__( + self, + labels_path: Optional[str] = None, + nonnative: bool = False, + *args, + **kwargs, + ): + """Initialize the app. + + Args: + labels_path: Path to saved :class:`Labels` dataset. + nonnative: Whether to use native or Qt file dialog. + Returns: + None. + """ super(MainWindow, self).__init__(*args, **kwargs) + self.nonnative = nonnative + self.labels = Labels() self.skeleton = Skeleton() self.labeled_frame = None self.video = None self.video_idx = None - self.mark_idx = None self.filename = None self._menu_actions = dict() self._buttons = dict() @@ -67,206 +109,381 @@ def __init__(self, data_path=None, video=None, import_data=None, *args, **kwargs self._auto_zoom = False self.changestack_clear() - self.initialize_gui() + self._initialize_gui() - if data_path is not None: - pass + self._file_dialog_options = 0 + if self.nonnative: + self._file_dialog_options = QFileDialog.DontUseNativeDialog - if import_data is not None: - self.importData(import_data) + if labels_path: + self.loadProject(labels_path) - # TODO: auto-add video to clean project if no data provided - # TODO: auto-select video if data provided, or add it to project - if video is not None: - self.addVideo(video) + def event(self, e: QEvent) -> bool: + """Custom event handler. - def changestack_push(self, change=None): - """Add to stack of changes made by user.""" + We use this to ignore events that would clear status bar. + + Args: + e: The event. + Returns: + True if we ignore event, otherwise returns whatever the usual + event handler would return. + """ + if e.type() == QEvent.StatusTip: + if e.tip() == "": + return True + return super().event(e) + + def changestack_push(self, change: bool = None): + """Adds to stack of changes made by user.""" # Currently the change doesn't store any data, and we're only using this # to determine if there are unsaved changes. Eventually we could use this # to support undo/redo. self._change_stack.append(change) def changestack_savepoint(self): + """Marks that project was just saved.""" self.changestack_push("SAVE") def changestack_clear(self): + """Clears stack of changes.""" self._change_stack = list() - def changestack_start_atomic(self, change=None): - pass + def changestack_start_atomic(self): + """Marks that we want to track a set of changes as a single change.""" + self.changestack_push("ATOMIC_START") def changestack_end_atomic(self): - pass + """Marks that we want finished the set of changes to track together.""" + self.changestack_push("ATOMIC_END") def changestack_has_changes(self) -> bool: + """Returns whether there are any unsaved changes.""" # True iff there are no unsaved changed - if len(self._change_stack) == 0: return False - if self._change_stack[-1] == "SAVE": return False + if len(self._change_stack) == 0: + return False + if self._change_stack[-1] == "SAVE": + return False return True @property def filename(self): + """Returns filename for current project.""" return self._filename @filename.setter def filename(self, x): + """Sets filename for current project. Doesn't load file.""" self._filename = x - if x is not None: self.setWindowTitle(x) + if x is not None: + self.setWindowTitle(x) - def initialize_gui(self): + def _initialize_gui(self): + """Creates menus, dock windows, starts timers to update gui state.""" - shortcut_yaml = resource_filename(Requirement.parse("sleap"),"sleap/config/shortcuts.yaml") - with open(shortcut_yaml, 'r') as f: - shortcuts = yaml.load(f, Loader=yaml.SafeLoader) + self._create_video_player() + self.statusBar() + self.load_overlays() + self._create_menus() + self._create_dock_windows() - for action in shortcuts: - key_string = shortcuts.get(action, None) - key_string = "" if key_string is None else key_string - if "." in key_string: - shortcuts[action] = eval(key_string) + # Create timer to update state of gui at regular intervals + self.update_gui_timer = QtCore.QTimer() + self.update_gui_timer.timeout.connect(self._update_gui_state) + self.update_gui_timer.start(0.1) - ####### Video player ####### + def _create_video_player(self): + """Creates and connects :class:`QtVideoPlayer` for gui.""" self.player = QtVideoPlayer(color_manager=self._color_manager) - self.player.changedPlot.connect(self.newFrame) - self.player.changedData.connect(lambda inst: self.changestack_push("viewer change")) + self.player.changedPlot.connect(self._after_plot_update) + self.player.changedData.connect( + lambda inst: self.changestack_push("viewer change") + ) self.player.view.instanceDoubleClicked.connect(self.doubleClickInstance) self.player.seekbar.selectionChanged.connect(lambda: self.updateStatusMessage()) self.setCentralWidget(self.player) - ####### Status bar ####### - self.statusBar() # Initialize status bar - - self.load_overlays() + def _create_menus(self): + """Creates main application menus.""" + shortcuts = Shortcuts() - ####### Menus ####### + def _menu_item(menu, key: str, name: str, action: Callable): + menu_item = menu.addAction(name, action, shortcuts[key]) + self._menu_actions[key] = menu_item ### File Menu ### fileMenu = self.menuBar().addMenu("File") - self._menu_actions["new"] = fileMenu.addAction("&New Project", self.newProject, shortcuts["new"]) - self._menu_actions["open"] = fileMenu.addAction("&Open Project...", self.openProject, shortcuts["open"]) + _menu_item(fileMenu, "new", "New Project", self.newProject) + _menu_item(fileMenu, "open", "Open Project...", self.openProject) + _menu_item( + fileMenu, "import predictions", "Import Labels...", self.importPredictions + ) + fileMenu.addSeparator() - self._menu_actions["add videos"] = fileMenu.addAction("Add Videos...", self.addVideo, shortcuts["add videos"]) + _menu_item(fileMenu, "add videos", "Add Videos...", self.addVideo) + fileMenu.addSeparator() - self._menu_actions["save"] = fileMenu.addAction("&Save", self.saveProject, shortcuts["save"]) - self._menu_actions["save as"] = fileMenu.addAction("Save As...", self.saveProjectAs, shortcuts["save as"]) + _menu_item(fileMenu, "save", "Save", self.saveProject) + _menu_item(fileMenu, "save as", "Save As...", self.saveProjectAs) + fileMenu.addSeparator() - self._menu_actions["close"] = fileMenu.addAction("Quit", self.close, shortcuts["close"]) + _menu_item(fileMenu, "close", "Quit", self.close) ### Go Menu ### goMenu = self.menuBar().addMenu("Go") - self._menu_actions["goto next"] = goMenu.addAction("Next Labeled Frame", self.nextLabeledFrame, shortcuts["goto next"]) - self._menu_actions["goto prev"] = goMenu.addAction("Previous Labeled Frame", self.previousLabeledFrame, shortcuts["goto prev"]) - - self._menu_actions["goto next user"] = goMenu.addAction("Next User Labeled Frame", self.nextUserLabeledFrame, shortcuts["goto next user"]) - - self._menu_actions["goto next suggestion"] = goMenu.addAction("Next Suggestion", self.nextSuggestedFrame, shortcuts["goto next suggestion"]) - self._menu_actions["goto prev suggestion"] = goMenu.addAction("Previous Suggestion", lambda:self.nextSuggestedFrame(-1), shortcuts["goto prev suggestion"]) - - self._menu_actions["goto next track"] = goMenu.addAction("Next Track Spawn Frame", self.nextTrackFrame, shortcuts["goto next track"]) + _menu_item( + goMenu, "goto next labeled", "Next Labeled Frame", self.nextLabeledFrame + ) + _menu_item( + goMenu, + "goto prev labeled", + "Previous Labeled Frame", + self.previousLabeledFrame, + ) + _menu_item( + goMenu, + "goto next user", + "Next User Labeled Frame", + self.nextUserLabeledFrame, + ) + _menu_item( + goMenu, "goto next suggestion", "Next Suggestion", self.nextSuggestedFrame + ) + _menu_item( + goMenu, + "goto prev suggestion", + "Previous Suggestion", + lambda: self.nextSuggestedFrame(-1), + ) + _menu_item( + goMenu, + "goto next track spawn", + "Next Track Spawn Frame", + self.nextTrackFrame, + ) goMenu.addSeparator() - self._menu_actions["next video"] = goMenu.addAction("Next Video", self.nextVideo, shortcuts["next video"]) - self._menu_actions["prev video"] = goMenu.addAction("Previous Video", self.previousVideo, shortcuts["prev video"]) + _menu_item(goMenu, "next video", "Next Video", self.nextVideo) + _menu_item(goMenu, "prev video", "Previous Video", self.previousVideo) goMenu.addSeparator() - self._menu_actions["goto frame"] = goMenu.addAction("Go to Frame...", self.gotoFrame, shortcuts["goto frame"]) - self._menu_actions["mark frame"] = goMenu.addAction("Mark Frame", self.markFrame, shortcuts["mark frame"]) - self._menu_actions["goto marked"] = goMenu.addAction("Go to Marked Frame", self.goMarkedFrame, shortcuts["goto marked"]) - + _menu_item(goMenu, "goto frame", "Go to Frame...", self.gotoFrame) ### View Menu ### viewMenu = self.menuBar().addMenu("View") + self.viewMenu = viewMenu # store as attribute so docks can add items viewMenu.addSeparator() - self._menu_actions["color predicted"] = viewMenu.addAction("Color Predicted Instances", self.toggleColorPredicted, shortcuts["color predicted"]) + _menu_item( + viewMenu, + "color predicted", + "Color Predicted Instances", + self.toggleColorPredicted, + ) self.paletteMenu = viewMenu.addMenu("Color Palette") for palette_name in self._color_manager.palette_names: - menu_item = self.paletteMenu.addAction(f"{palette_name}", - lambda x=palette_name: self.setPalette(x)) + menu_item = self.paletteMenu.addAction( + f"{palette_name}", lambda x=palette_name: self.setPalette(x) + ) menu_item.setCheckable(True) self.setPalette("standard") viewMenu.addSeparator() - self._menu_actions["show labels"] = viewMenu.addAction("Show Node Names", self.toggleLabels, shortcuts["show labels"]) - self._menu_actions["show edges"] = viewMenu.addAction("Show Edges", self.toggleEdges, shortcuts["show edges"]) - self._menu_actions["show trails"] = viewMenu.addAction("Show Trails", self.toggleTrails, shortcuts["show trails"]) + self.seekbarHeaderMenu = viewMenu.addMenu("Seekbar Header") + headers = ( + "None", + "Point Displacement (sum)", + "Point Displacement (max)", + "Instance Score (sum)", + "Instance Score (min)", + "Point Score (sum)", + "Point Score (min)", + "Number of predicted points", + ) + for header in headers: + menu_item = self.seekbarHeaderMenu.addAction( + header, lambda x=header: self.setSeekbarHeader(x) + ) + menu_item.setCheckable(True) + self.setSeekbarHeader("None") + + viewMenu.addSeparator() + + _menu_item(viewMenu, "show labels", "Show Node Names", self.toggleLabels) + _menu_item(viewMenu, "show edges", "Show Edges", self.toggleEdges) + _menu_item(viewMenu, "show trails", "Show Trails", self.toggleTrails) self.trailLengthMenu = viewMenu.addMenu("Trail Length") for length_option in (4, 10, 20): - menu_item = self.trailLengthMenu.addAction(f"{length_option}", - lambda x=length_option: self.setTrailLength(x)) + menu_item = self.trailLengthMenu.addAction( + f"{length_option}", lambda x=length_option: self.setTrailLength(x) + ) menu_item.setCheckable(True) viewMenu.addSeparator() - self._menu_actions["fit"] = viewMenu.addAction("Fit Instances to View", self.toggleAutoZoom, shortcuts["fit"]) + _menu_item(viewMenu, "fit", "Fit Instances to View", self.toggleAutoZoom) viewMenu.addSeparator() # set menu checkmarks - self._menu_actions["show labels"].setCheckable(True); self._menu_actions["show labels"].setChecked(self._show_labels) - self._menu_actions["show edges"].setCheckable(True); self._menu_actions["show edges"].setChecked(self._show_edges) - self._menu_actions["show trails"].setCheckable(True); self._menu_actions["show trails"].setChecked(self.overlays["trails"].show) - self._menu_actions["color predicted"].setCheckable(True); self._menu_actions["color predicted"].setChecked(self.overlays["instance"].color_predicted) + self._menu_actions["show labels"].setCheckable(True) + self._menu_actions["show labels"].setChecked(self._show_labels) + self._menu_actions["show edges"].setCheckable(True) + self._menu_actions["show edges"].setChecked(self._show_edges) + self._menu_actions["show trails"].setCheckable(True) + self._menu_actions["show trails"].setChecked(self.overlays["trails"].show) + self._menu_actions["color predicted"].setCheckable(True) + self._menu_actions["color predicted"].setChecked( + self.overlays["instance"].color_predicted + ) self._menu_actions["fit"].setCheckable(True) ### Label Menu ### labelMenu = self.menuBar().addMenu("Labels") - self._menu_actions["add instance"] = labelMenu.addAction("Add Instance", self.newInstance, shortcuts["add instance"]) - self._menu_actions["delete instance"] = labelMenu.addAction("Delete Instance", self.deleteSelectedInstance, shortcuts["delete instance"]) + _menu_item(labelMenu, "add instance", "Add Instance", self.newInstance) + _menu_item( + labelMenu, "delete instance", "Delete Instance", self.deleteSelectedInstance + ) labelMenu.addSeparator() self.track_menu = labelMenu.addMenu("Set Instance Track") - self._menu_actions["transpose"] = labelMenu.addAction("Transpose Instance Tracks", self.transposeInstance, shortcuts["transpose"]) - self._menu_actions["delete track"] = labelMenu.addAction("Delete Instance and Track", self.deleteSelectedInstanceTrack, shortcuts["delete track"]) + _menu_item( + labelMenu, "transpose", "Transpose Instance Tracks", self.transposeInstance + ) + _menu_item( + labelMenu, + "delete track", + "Delete Instance and Track", + self.deleteSelectedInstanceTrack, + ) labelMenu.addSeparator() - self._menu_actions["select next"] = labelMenu.addAction("Select Next Instance", self.player.view.nextSelection, shortcuts["select next"]) - self._menu_actions["clear selection"] = labelMenu.addAction("Clear Selection", self.player.view.clearSelection, shortcuts["clear selection"]) + _menu_item( + labelMenu, + "select next", + "Select Next Instance", + self.player.view.nextSelection, + ) + _menu_item( + labelMenu, + "clear selection", + "Clear Selection", + self.player.view.clearSelection, + ) labelMenu.addSeparator() ### Predict Menu ### predictionMenu = self.menuBar().addMenu("Predict") - self._menu_actions["active learning"] = predictionMenu.addAction("Run Active Learning...", self.runActiveLearning, shortcuts["learning"]) - self._menu_actions["inference"] = predictionMenu.addAction("Run Inference...", self.runInference) - self._menu_actions["learning expert"] = predictionMenu.addAction("Expert Controls...", self.runLearningExpert) + + _menu_item( + predictionMenu, + "active learning", + "Run Active Learning...", + lambda: self.showLearningDialog("learning"), + ) + _menu_item( + predictionMenu, + "inference", + "Run Inference...", + lambda: self.showLearningDialog("inference"), + ) + _menu_item( + predictionMenu, + "learning expert", + "Expert Controls...", + lambda: self.showLearningDialog("expert"), + ) + predictionMenu.addSeparator() - self._menu_actions["negative sample"] = predictionMenu.addAction("Mark Negative Training Sample...", self.markNegativeAnchor) + _menu_item( + predictionMenu, + "negative sample", + "Mark Negative Training Sample...", + self.markNegativeAnchor, + ) + _menu_item( + predictionMenu, + "clear negative samples", + "Clear Current Frame Negative Samples", + self.clearFrameNegativeAnchors, + ) + predictionMenu.addSeparator() - self._menu_actions["visualize models"] = predictionMenu.addAction("Visualize Model Outputs...", self.visualizeOutputs) - self._menu_actions["import predictions"] = predictionMenu.addAction("Import Predictions...", self.importPredictions) + _menu_item( + predictionMenu, + "visualize models", + "Visualize Model Outputs...", + self.visualizeOutputs, + ) + predictionMenu.addSeparator() - self._menu_actions["remove predictions"] = predictionMenu.addAction("Delete All Predictions...", self.deletePredictions) - self._menu_actions["remove clip predictions"] = predictionMenu.addAction("Delete Predictions from Clip...", self.deleteClipPredictions, shortcuts["delete clip"]) - self._menu_actions["remove area predictions"] = predictionMenu.addAction("Delete Predictions from Area...", self.deleteAreaPredictions, shortcuts["delete area"]) - self._menu_actions["remove score predictions"] = predictionMenu.addAction("Delete Predictions with Low Score...", self.deleteLowScorePredictions) - self._menu_actions["remove frame limit predictions"] = predictionMenu.addAction("Delete Predictions beyond Frame Limit...", self.deleteFrameLimitPredictions) + _menu_item( + predictionMenu, + "remove predictions", + "Delete All Predictions...", + self.deletePredictions, + ) + _menu_item( + predictionMenu, + "remove clip predictions", + "Delete Predictions from Clip...", + self.deleteClipPredictions, + ) + _menu_item( + predictionMenu, + "remove area predictions", + "Delete Predictions from Area...", + self.deleteAreaPredictions, + ) + _menu_item( + predictionMenu, + "remove score predictions", + "Delete Predictions with Low Score...", + self.deleteLowScorePredictions, + ) + _menu_item( + predictionMenu, + "remove frame limit predictions", + "Delete Predictions beyond Frame Limit...", + self.deleteFrameLimitPredictions, + ) + predictionMenu.addSeparator() - self._menu_actions["export frames"] = predictionMenu.addAction("Export Training Package...", self.exportLabeledFrames) - self._menu_actions["export clip"] = predictionMenu.addAction("Export Labeled Clip...", self.exportLabeledClip, shortcuts["export clip"]) + _menu_item( + predictionMenu, + "export frames", + "Export Training Package...", + self.exportLabeledFrames, + ) + _menu_item( + predictionMenu, + "export clip", + "Export Labeled Clip...", + self.exportLabeledClip, + ) ############ helpMenu = self.menuBar().addMenu("Help") - helpMenu.addAction("Documentation", self.openDocumentation) helpMenu.addAction("Keyboard Reference", self.openKeyRef) - helpMenu.addAction("About", self.openAbout) - ####### Helpers ####### + def _create_dock_windows(self): + """Create dock windows and connects them to gui.""" + def _make_dock(name, widgets=[], tab_with=None): dock = QDockWidget(name) dock.setAllowedAreas(Qt.LeftDockWidgetArea | Qt.RightDockWidgetArea) @@ -277,7 +494,7 @@ def _make_dock(name, widgets=[], tab_with=None): dock_widget.setLayout(layout) dock.setWidget(dock_widget) self.addDockWidget(Qt.RightDockWidgetArea, dock) - viewMenu.addAction(dock.toggleViewAction()) + self.viewMenu.addAction(dock.toggleViewAction()) if tab_with is not None: self.tabifyDockWidget(tab_with, dock) return layout @@ -288,20 +505,26 @@ def _make_dock(name, widgets=[], tab_with=None): videos_layout.addWidget(self.videosTable) hb = QHBoxLayout() btn = QPushButton("Show video") - btn.clicked.connect(self.activateSelectedVideo); hb.addWidget(btn) + btn.clicked.connect(self.activateSelectedVideo) + hb.addWidget(btn) self._buttons["show video"] = btn btn = QPushButton("Add videos") - btn.clicked.connect(self.addVideo); hb.addWidget(btn) + btn.clicked.connect(self.addVideo) + hb.addWidget(btn) btn = QPushButton("Remove video") - btn.clicked.connect(self.removeVideo); hb.addWidget(btn) + btn.clicked.connect(self.removeVideo) + hb.addWidget(btn) self._buttons["remove video"] = btn - hbw = QWidget(); hbw.setLayout(hb) + hbw = QWidget() + hbw.setLayout(hb) videos_layout.addWidget(hbw) self.videosTable.doubleClicked.connect(self.activateSelectedVideo) ####### Skeleton ####### - skeleton_layout = _make_dock("Skeleton", tab_with=videos_layout.parent().parent()) + skeleton_layout = _make_dock( + "Skeleton", tab_with=videos_layout.parent().parent() + ) gb = QGroupBox("Nodes") vb = QVBoxLayout() @@ -309,44 +532,63 @@ def _make_dock(name, widgets=[], tab_with=None): vb.addWidget(self.skeletonNodesTable) hb = QHBoxLayout() btn = QPushButton("New node") - btn.clicked.connect(self.newNode); hb.addWidget(btn) + btn.clicked.connect(self.newNode) + hb.addWidget(btn) btn = QPushButton("Delete node") - btn.clicked.connect(self.deleteNode); hb.addWidget(btn) + btn.clicked.connect(self.deleteNode) + hb.addWidget(btn) self._buttons["delete node"] = btn - hbw = QWidget(); hbw.setLayout(hb) + hbw = QWidget() + hbw.setLayout(hb) vb.addWidget(hbw) gb.setLayout(vb) skeleton_layout.addWidget(gb) + def _update_edge_src(): + self.skeletonEdgesDst.model().skeleton = self.skeleton + gb = QGroupBox("Edges") vb = QVBoxLayout() self.skeletonEdgesTable = SkeletonEdgesTable(self.skeleton) vb.addWidget(self.skeletonEdgesTable) hb = QHBoxLayout() - self.skeletonEdgesSrc = QComboBox(); self.skeletonEdgesSrc.setEditable(False); self.skeletonEdgesSrc.currentIndexChanged.connect(self.selectSkeletonEdgeSrc) + self.skeletonEdgesSrc = QComboBox() + self.skeletonEdgesSrc.setEditable(False) + self.skeletonEdgesSrc.currentIndexChanged.connect(_update_edge_src) self.skeletonEdgesSrc.setModel(SkeletonNodeModel(self.skeleton)) hb.addWidget(self.skeletonEdgesSrc) hb.addWidget(QLabel("to")) - self.skeletonEdgesDst = QComboBox(); self.skeletonEdgesDst.setEditable(False) + self.skeletonEdgesDst = QComboBox() + self.skeletonEdgesDst.setEditable(False) hb.addWidget(self.skeletonEdgesDst) - self.skeletonEdgesDst.setModel(SkeletonNodeModel(self.skeleton, lambda: self.skeletonEdgesSrc.currentText())) + self.skeletonEdgesDst.setModel( + SkeletonNodeModel( + self.skeleton, lambda: self.skeletonEdgesSrc.currentText() + ) + ) btn = QPushButton("Add edge") - btn.clicked.connect(self.newEdge); hb.addWidget(btn) + btn.clicked.connect(self.newEdge) + hb.addWidget(btn) self._buttons["add edge"] = btn btn = QPushButton("Delete edge") - btn.clicked.connect(self.deleteEdge); hb.addWidget(btn) + btn.clicked.connect(self.deleteEdge) + hb.addWidget(btn) self._buttons["delete edge"] = btn - hbw = QWidget(); hbw.setLayout(hb) + hbw = QWidget() + hbw.setLayout(hb) vb.addWidget(hbw) gb.setLayout(vb) skeleton_layout.addWidget(gb) hb = QHBoxLayout() btn = QPushButton("Load Skeleton") - btn.clicked.connect(self.openSkeleton); hb.addWidget(btn) + btn.clicked.connect(self.openSkeleton) + hb.addWidget(btn) btn = QPushButton("Save Skeleton") - btn.clicked.connect(self.saveSkeleton); hb.addWidget(btn) - hbw = QWidget(); hbw.setLayout(hb) + btn.clicked.connect(self.saveSkeleton) + hb.addWidget(btn) + hbw = QWidget() + hbw.setLayout(hb) skeleton_layout.addWidget(hbw) # update edge UI when change to nodes @@ -359,20 +601,32 @@ def _make_dock(name, widgets=[], tab_with=None): instances_layout.addWidget(self.instancesTable) hb = QHBoxLayout() btn = QPushButton("New instance") - btn.clicked.connect(lambda x: self.newInstance()); hb.addWidget(btn) + btn.clicked.connect(lambda x: self.newInstance()) + hb.addWidget(btn) btn = QPushButton("Delete instance") - btn.clicked.connect(self.deleteSelectedInstance); hb.addWidget(btn) + btn.clicked.connect(self.deleteSelectedInstance) + hb.addWidget(btn) self._buttons["delete instance"] = btn - hbw = QWidget(); hbw.setLayout(hb) + hbw = QWidget() + hbw.setLayout(hb) instances_layout.addWidget(hbw) def update_instance_table_selection(): - cur_video_instance = self.player.view.getSelection() - if cur_video_instance is None: cur_video_instance = -1 - table_index = self.instancesTable.model().createIndex(cur_video_instance, 0) - self.instancesTable.setCurrentIndex(table_index) + inst_selected = self.player.view.getSelectionInstance() - self.instancesTable.selectionChangedSignal.connect(lambda row: self.player.view.selectInstance(row, from_all=True, signal=False)) + if not inst_selected: + return + + idx = -1 + if inst_selected in self.labeled_frame.instances_to_show: + idx = self.labeled_frame.instances_to_show.index(inst_selected) + + table_row_idx = self.instancesTable.model().createIndex(idx, 0) + self.instancesTable.setCurrentIndex(table_row_idx) + + self.instancesTable.selectionChangedSignal.connect( + lambda inst: self.player.view.selectInstance(inst, signal=False) + ) self.player.view.updatedSelection.connect(update_instance_table_selection) # update track UI when change to track name @@ -386,50 +640,58 @@ def update_instance_table_selection(): hb = QHBoxLayout() btn = QPushButton("Prev") - btn.clicked.connect(lambda:self.nextSuggestedFrame(-1)); hb.addWidget(btn) + btn.clicked.connect(lambda: self.nextSuggestedFrame(-1)) + hb.addWidget(btn) self.suggested_count_label = QLabel() hb.addWidget(self.suggested_count_label) btn = QPushButton("Next") - btn.clicked.connect(lambda:self.nextSuggestedFrame()); hb.addWidget(btn) - hbw = QWidget(); hbw.setLayout(hb) + btn.clicked.connect(lambda: self.nextSuggestedFrame()) + hb.addWidget(btn) + hbw = QWidget() + hbw.setLayout(hb) suggestions_layout.addWidget(hbw) - suggestions_yaml = resource_filename(Requirement.parse("sleap"),"sleap/config/suggestions.yaml") - form_wid = YamlFormWidget(yaml_file=suggestions_yaml, title="Generate Suggestions") + suggestions_yaml = resource_filename( + Requirement.parse("sleap"), "sleap/config/suggestions.yaml" + ) + form_wid = YamlFormWidget( + yaml_file=suggestions_yaml, title="Generate Suggestions" + ) form_wid.mainAction.connect(self.generateSuggestions) suggestions_layout.addWidget(form_wid) - self.suggestionsTable.doubleClicked.connect(lambda table_idx: self.gotoVideoAndFrame(*self.labels.get_suggestions()[table_idx.row()])) - - # - # Set timer to update state of gui at regular intervals - # - self.update_gui_timer = QtCore.QTimer() - self.update_gui_timer.timeout.connect(self.update_gui_state) - self.update_gui_timer.start(0.1) + self.suggestionsTable.doubleClicked.connect( + lambda table_idx: self.gotoVideoAndFrame( + *self.labels.get_suggestions()[table_idx.row()] + ) + ) def load_overlays(self): - self.overlays["trails"] = TrackTrailOverlay( - labels = self.labels, - scene = self.player.view.scene, - color_manager = self._color_manager) - - self.overlays["instance"] = InstanceOverlay( - labels = self.labels, - player = self.player, - color_manager = self._color_manager) - - def update_gui_state(self): - has_selected_instance = (self.player.view.getSelection() is not None) + """Load all standard video overlays.""" + self.overlays["track_labels"] = TrackListOverlay(self.labels, self.player) + self.overlays["negative"] = NegativeAnchorOverlay(self.labels, self.player) + self.overlays["trails"] = TrackTrailOverlay(self.labels, self.player) + self.overlays["instance"] = InstanceOverlay(self.labels, self.player) + + def _update_gui_state(self): + """Enable/disable gui items based on current state.""" + has_selected_instance = self.player.view.getSelection() is not None has_unsaved_changes = self.changestack_has_changes() - has_multiple_videos = (self.labels is not None and len(self.labels.videos) > 1) - has_labeled_frames = any((lf.video == self.video for lf in self.labels)) - has_suggestions = (len(self.labels.suggestions) > 0) - has_tracks = (len(self.labels.tracks) > 0) - has_multiple_instances = (self.labeled_frame is not None and len(self.labeled_frame.instances) > 1) + has_multiple_videos = self.labels is not None and len(self.labels.videos) > 1 + has_labeled_frames = self.labels is not None and any( + (lf.video == self.video for lf in self.labels) + ) + has_suggestions = self.labels is not None and (len(self.labels.suggestions) > 0) + has_tracks = self.labels is not None and (len(self.labels.tracks) > 0) + has_multiple_instances = ( + self.labeled_frame is not None and len(self.labeled_frame.instances) > 1 + ) # todo: exclude predicted instances from count - has_nodes_selected = (self.skeletonEdgesSrc.currentIndex() > -1 and - self.skeletonEdgesDst.currentIndex() > -1) + has_nodes_selected = ( + self.skeletonEdgesSrc.currentIndex() > -1 + and self.skeletonEdgesDst.currentIndex() > -1 + ) + control_key_down = QApplication.queryKeyboardModifiers() == Qt.ControlModifier # Update menus @@ -440,60 +702,93 @@ def update_gui_state(self): self._menu_actions["transpose"].setEnabled(has_multiple_instances) self._menu_actions["save"].setEnabled(has_unsaved_changes) - self._menu_actions["goto marked"].setEnabled(self.mark_idx is not None) self._menu_actions["next video"].setEnabled(has_multiple_videos) self._menu_actions["prev video"].setEnabled(has_multiple_videos) - self._menu_actions["goto next"].setEnabled(has_labeled_frames) - self._menu_actions["goto prev"].setEnabled(has_labeled_frames) + self._menu_actions["goto next labeled"].setEnabled(has_labeled_frames) + self._menu_actions["goto prev labeled"].setEnabled(has_labeled_frames) self._menu_actions["goto next suggestion"].setEnabled(has_suggestions) self._menu_actions["goto prev suggestion"].setEnabled(has_suggestions) - self._menu_actions["goto next track"].setEnabled(has_tracks) + self._menu_actions["goto next track spawn"].setEnabled(has_tracks) # Update buttons self._buttons["add edge"].setEnabled(has_nodes_selected) - self._buttons["delete edge"].setEnabled(self.skeletonEdgesTable.currentIndex().isValid()) - self._buttons["delete node"].setEnabled(self.skeletonNodesTable.currentIndex().isValid()) - self._buttons["show video"].setEnabled(self.videosTable.currentIndex().isValid()) - self._buttons["remove video"].setEnabled(self.videosTable.currentIndex().isValid()) - self._buttons["delete instance"].setEnabled(self.instancesTable.currentIndex().isValid()) - - def update_data_views(self): - self.videosTable.model().videos = self.labels.videos - - self.skeletonNodesTable.model().skeleton = self.skeleton - self.skeletonEdgesTable.model().skeleton = self.skeleton - self.skeletonEdgesSrc.model().skeleton = self.skeleton - self.skeletonEdgesDst.model().skeleton = self.skeleton - - self.instancesTable.model().labels = self.labels - self.instancesTable.model().labeled_frame = self.labeled_frame - self.instancesTable.model().color_manager = self._color_manager - - self.suggestionsTable.model().labels = self.labels - - # update count of suggested frames w/ labeled instances - suggestion_status_text = "" - suggestion_list = self.labels.get_suggestions() - if len(suggestion_list): - suggestion_label_counts = [self.labels.instance_count(video, frame_idx) - for (video, frame_idx) in suggestion_list] - labeled_count = len(suggestion_list) - suggestion_label_counts.count(0) - suggestion_status_text = f"{labeled_count}/{len(suggestion_list)} labeled" - self.suggested_count_label.setText(suggestion_status_text) - - def keyPressEvent(self, event: QKeyEvent): - if event.key() == Qt.Key_Q: - self.close() - else: - event.ignore() # Kicks the event up to parent + self._buttons["delete edge"].setEnabled( + self.skeletonEdgesTable.currentIndex().isValid() + ) + self._buttons["delete node"].setEnabled( + self.skeletonNodesTable.currentIndex().isValid() + ) + self._buttons["show video"].setEnabled( + self.videosTable.currentIndex().isValid() + ) + self._buttons["remove video"].setEnabled( + self.videosTable.currentIndex().isValid() + ) + self._buttons["delete instance"].setEnabled( + self.instancesTable.currentIndex().isValid() + ) + + # Update overlays + self.overlays["track_labels"].visible = ( + control_key_down and has_selected_instance + ) + + def _update_data_views(self, *update): + """Update data used by data view table models. + + Args: + Accepts names of what data to update as unnamed string arguments: + "video", "skeleton", "labels", "frame", "suggestions" + If no arguments are given, then everything is updated. + Returns: + None. + """ + update = update or ("video", "skeleton", "labels", "frame", "suggestions") + + if len(self.skeleton.nodes) == 0 and len(self.labels.skeletons): + self.skeleton = self.labels.skeletons[0] + + if "video" in update: + self.videosTable.model().items = self.labels.videos + + if "skeleton" in update: + self.skeletonNodesTable.model().skeleton = self.skeleton + self.skeletonEdgesTable.model().skeleton = self.skeleton + self.skeletonEdgesSrc.model().skeleton = self.skeleton + self.skeletonEdgesDst.model().skeleton = self.skeleton + + if "labels" in update: + self.instancesTable.model().labels = self.labels + self.instancesTable.model().color_manager = self._color_manager + + if "frame" in update: + self.instancesTable.model().labeled_frame = self.labeled_frame + + if "suggestions" in update: + self.suggestionsTable.model().labels = self.labels + + # update count of suggested frames w/ labeled instances + suggestion_status_text = "" + suggestion_list = self.labels.get_suggestions() + if len(suggestion_list): + suggestion_label_counts = [ + self.labels.instance_count(video, frame_idx) + for (video, frame_idx) in suggestion_list + ] + labeled_count = len(suggestion_list) - suggestion_label_counts.count(0) + suggestion_status_text = ( + f"{labeled_count}/{len(suggestion_list)} labeled" + ) + self.suggested_count_label.setText(suggestion_status_text) def plotFrame(self, *args, **kwargs): - """Wrap call to player.plot so we can redraw/update things.""" - if self.video is None: return + """Plots (or replots) current frame.""" + if self.video is None: + return self.player.plot(*args, **kwargs) self.player.showLabels(self._show_labels) @@ -501,13 +796,69 @@ def plotFrame(self, *args, **kwargs): if self._auto_zoom: self.player.zoomToFit() - def importData(self, filename=None, do_load=True): + def _after_plot_update(self, player, frame_idx, selected_inst): + """Called each time a new frame is drawn.""" + + # Store the current LabeledFrame (or make new, empty object) + self.labeled_frame = self.labels.find(self.video, frame_idx, return_new=True)[0] + + # Show instances, etc, for this frame + for overlay in self.overlays.values(): + overlay.add_to_scene(self.video, frame_idx) + + # Select instance if there was already selection + if selected_inst is not None: + player.view.selectInstance(selected_inst) + + # Update related displays + self.updateStatusMessage() + self._update_data_views("frame") + + # Trigger event after the overlays have been added + player.view.updatedViewer.emit() + + def updateStatusMessage(self, message: Optional[str] = None): + """Updates status bar.""" + if message is None: + message = f"Frame: {self.player.frame_idx+1}/{len(self.video)}" + if self.player.seekbar.hasSelection(): + start, end = self.player.seekbar.getSelection() + message += f" (selection: {start}-{end})" + + if len(self.labels.videos) > 1: + message += f" of video {self.labels.videos.index(self.video)}" + + message += f" Labeled Frames: " + if self.video is not None: + message += ( + f"{len(self.labels.get_video_user_labeled_frames(self.video))}" + ) + if len(self.labels.videos) > 1: + message += " in video, " + if len(self.labels.videos) > 1: + message += f"{len(self.labels.user_labeled_frames)} in project" + + self.statusBar().showMessage(message) + + def loadProject(self, filename: Optional[str] = None): + """ + Loads given labels file into GUI. + + Args: + filename: The path to the saved labels dataset. If None, + then don't do anything. + + Returns: + None: + """ show_msg = False - if len(filename) == 0: return + if len(filename) == 0: + return gui_video_callback = Labels.make_gui_video_callback( - search_paths=[os.path.dirname(filename)]) + search_paths=[os.path.dirname(filename)] + ) has_loaded = False labels = None @@ -523,55 +874,68 @@ def importData(self, filename=None, do_load=True): print(e) QMessageBox(text=f"Unable to load {filename}.").exec_() - if do_load: + self.labels = labels + self.filename = filename - self.labels = labels - self.filename = filename - - if has_loaded: - self.changestack_clear() - self._color_manager.labels = self.labels - self._color_manager.set_palette(self._color_palette) + if has_loaded: + self.changestack_clear() + self._color_manager.labels = self.labels + self._color_manager.set_palette(self._color_palette) - self.load_overlays() + self.load_overlays() - self.setTrailLength(self.overlays["trails"].trail_length) + self.setTrailLength(self.overlays["trails"].trail_length) - if show_msg: - msgBox = QMessageBox(text=f"Imported {len(self.labels)} labeled frames.") - msgBox.exec_() + if show_msg: + msgBox = QMessageBox( + text=f"Imported {len(self.labels)} labeled frames." + ) + msgBox.exec_() - if len(self.labels.skeletons): - # TODO: add support for multiple skeletons - self.skeleton = self.labels.skeletons[0] + if len(self.labels.skeletons): + # TODO: add support for multiple skeletons + self.skeleton = self.labels.skeletons[0] - # Update UI tables - self.update_data_views() + # Update UI tables + self._update_data_views() - # Load first video + # Load first video + if len(self.labels.videos): self.loadVideo(self.labels.videos[0], 0) - # Update track menu options - self.updateTrackMenu() - else: - return labels + # Update track menu options + self.updateTrackMenu() def updateTrackMenu(self): + """Updates track menu options.""" self.track_menu.clear() for track in self.labels.tracks: key_command = "" if self.labels.tracks.index(track) < 9: key_command = Qt.CTRL + Qt.Key_0 + self.labels.tracks.index(track) + 1 - self.track_menu.addAction(f"{track.name}", lambda x=track:self.setInstanceTrack(x), key_command) + self.track_menu.addAction( + f"{track.name}", lambda x=track: self.setInstanceTrack(x), key_command + ) self.track_menu.addAction("New Track", self.addTrack, Qt.CTRL + Qt.Key_0) def activateSelectedVideo(self, x): + """Activates video selected in table.""" # Get selected video idx = self.videosTable.currentIndex() - if not idx.isValid(): return + if not idx.isValid(): + return self.loadVideo(self.labels.videos[idx.row()], idx.row()) - def addVideo(self, filename=None): + def addVideo(self, filename: Optional[str] = None): + """Shows gui for adding video to project. + + Args: + filename: If given, then we just load this video. If not given, + then we show dialog for importing videos. + + Returns: + None. + """ # Browse for file video = None if isinstance(filename, str): @@ -589,15 +953,17 @@ def addVideo(self, filename=None): # Load if no video currently loaded if self.video is None: - self.loadVideo(video, len(self.labels.videos)-1) + self.loadVideo(video, len(self.labels.videos) - 1) # Update data model/view - self.update_data_views() + self._update_data_views("video") def removeVideo(self): + """Removes video (selected in table) from project.""" # Get selected video idx = self.videosTable.currentIndex() - if not idx.isValid(): return + if not idx.isValid(): + return video = self.labels.videos[idx.row()] # Count labeled frames for this video @@ -605,7 +971,14 @@ def removeVideo(self): # Warn if there are labels that will be deleted if n > 0: - response = QMessageBox.critical(self, "Removing video with labels", f"{n} labeled frames in this video will be deleted, are you sure you want to remove this video?", QMessageBox.Yes, QMessageBox.No) + response = QMessageBox.critical( + self, + "Removing video with labels", + f"{n} labeled frames in this video will be deleted, " + "are you sure you want to remove this video?", + QMessageBox.Yes, + QMessageBox.No, + ) if response == QMessageBox.No: return @@ -614,7 +987,7 @@ def removeVideo(self): self.changestack_push("remove video") # Update data model - self.update_data_views() + self._update_data_views() # Update view if this was the current video if self.video == video: @@ -625,13 +998,14 @@ def removeVideo(self): new_idx = min(idx.row(), len(self.labels.videos) - 1) self.loadVideo(self.labels.videos[new_idx], new_idx) - def loadVideo(self, video:Video, video_idx: int = None): - # Clear video frame mark - self.mark_idx = None + def loadVideo(self, video: Video, video_idx: int = None): + """Activates video in gui.""" # Update current video instance self.video = video - self.video_idx = video_idx if video_idx is not None else self.labels.videos.index(video) + self.video_idx = ( + video_idx if video_idx is not None else self.labels.videos.index(video) + ) # Load video in player widget self.player.load_video(self.video) @@ -645,10 +1019,18 @@ def loadVideo(self, video:Video, video_idx: int = None): self.plotFrame(last_label.frame_idx) def openSkeleton(self): + """Shows gui for loading saved skeleton into project.""" filters = ["JSON skeleton (*.json)", "HDF5 skeleton (*.h5 *.hdf5)"] - filename, selected_filter = QFileDialog.getOpenFileName(self, dir=None, caption="Open skeleton...", filter=";;".join(filters)) - - if len(filename) == 0: return + filename, selected_filter = openFileDialog( + self, + dir=None, + caption="Open skeleton...", + filter=";;".join(filters), + options=self._file_dialog_options, + ) + + if len(filename) == 0: + return if filename.endswith(".json"): self.skeleton = Skeleton.load_json(filename) @@ -657,15 +1039,27 @@ def openSkeleton(self): if len(sk_list): self.skeleton = sk_list[0] + if self.skeleton not in self.labels: + self.labels.skeletons.append(self.skeleton) + self.changestack_push("new skeleton") + # Update data model - self.update_data_views() + self._update_data_views() def saveSkeleton(self): + """Shows gui for saving skeleton from project.""" default_name = "skeleton.json" filters = ["JSON skeleton (*.json)", "HDF5 skeleton (*.h5 *.hdf5)"] - filename, selected_filter = QFileDialog.getSaveFileName(self, caption="Save As...", dir=default_name, filter=";;".join(filters)) - - if len(filename) == 0: return + filename, selected_filter = saveFileDialog( + self, + caption="Save As...", + dir=default_name, + filter=";;".join(filters), + options=self._file_dialog_options, + ) + + if len(filename) == 0: + return if filename.endswith(".json"): self.skeleton.save_json(filename) @@ -673,6 +1067,7 @@ def saveSkeleton(self): self.skeleton.save_hdf5(filename) def newNode(self): + """Adds new node to skeleton.""" # Find new part name part_name = "new_part" i = 1 @@ -685,14 +1080,16 @@ def newNode(self): self.changestack_push("new node") # Update data model - self.update_data_views() + self._update_data_views() self.plotFrame() def deleteNode(self): + """Removes (currently selected) node from skeleton.""" # Get selected node idx = self.skeletonNodesTable.currentIndex() - if not idx.isValid(): return + if not idx.isValid(): + return node = self.skeleton.nodes[idx.row()] # Remove @@ -700,19 +1097,18 @@ def deleteNode(self): self.changestack_push("delete node") # Update data model - self.update_data_views() + self._update_data_views() # Replot instances self.plotFrame() - def selectSkeletonEdgeSrc(self): - self.skeletonEdgesDst.model().skeleton = self.skeleton - def updateEdges(self): - self.update_data_views() + """Called when edges in skeleton have been changed.""" + self._update_data_views() self.plotFrame() def newEdge(self): + """Adds new edge to skeleton.""" # TODO: Move this to unified data model # Get selected nodes @@ -728,17 +1124,18 @@ def newEdge(self): self.changestack_push("new edge") # Update data model - self.update_data_views() + self._update_data_views() self.plotFrame() - def deleteEdge(self): + """Removes (currently selected) edge from skeleton.""" # TODO: Move this to unified data model # Get selected edge idx = self.skeletonEdgesTable.currentIndex() - if not idx.isValid(): return + if not idx.isValid(): + return edge = self.skeleton.edges[idx.row()] # Delete edge @@ -746,75 +1143,135 @@ def deleteEdge(self): self.changestack_push("delete edge") # Update data model - self.update_data_views() + self._update_data_views() self.plotFrame() def updateSeekbarMarks(self): + """Updates marks on seekbar.""" self.player.seekbar.setTracksFromLabels(self.labels, self.video) - def generateSuggestions(self, params): + def setSeekbarHeader(self, graph_name): + """Updates graph shown in seekbar header.""" + data_obj = StatisticSeries(self.labels) + header_functions = { + "Point Displacement (sum)": data_obj.get_point_displacement_series, + "Point Displacement (max)": data_obj.get_point_displacement_series, + "Instance Score (sum)": data_obj.get_instance_score_series, + "Instance Score (min)": data_obj.get_instance_score_series, + "Point Score (sum)": data_obj.get_point_score_series, + "Point Score (min)": data_obj.get_point_score_series, + "Number of predicted points": data_obj.get_point_count_series, + } + + self._menu_check_single(self.seekbarHeaderMenu, graph_name) + + if graph_name == "None": + self.player.seekbar.clearHeader() + else: + if graph_name in header_functions: + kwargs = dict(video=self.video) + reduction_name = re.search("\((sum|max|min)\)", graph_name) + if reduction_name is not None: + kwargs["reduction"] = reduction_name.group(1) + series = header_functions[graph_name](**kwargs) + self.player.seekbar.setHeaderSeries(series) + else: + print(f"Could not find function for {header_functions}") + + def generateSuggestions(self, params: Dict): + """Generates suggestions using given params dictionary.""" new_suggestions = dict() for video in self.labels.videos: new_suggestions[video] = VideoFrameSuggestions.suggest( - video=video, - labels=self.labels, - params=params) + video=video, labels=self.labels, params=params + ) self.labels.set_suggestions(new_suggestions) - self.update_data_views() + + self._update_data_views("suggestions") self.updateSeekbarMarks() def _frames_for_prediction(self): - - def remove_user_labeled(video, frames, user_labeled_frames=self.labels.user_labeled_frames): - if len(frames) == 0: return frames - video_user_labeled_frame_idxs = [lf.frame_idx for lf in user_labeled_frames - if lf.video == video] + """Builds options for frames on which to run inference. + + Args: + None. + Returns: + Dictionary, keys are names of options (e.g., "clip", "random"), + values are {video: list of frame indices} dictionaries. + """ + + def remove_user_labeled( + video, frames, user_labeled_frames=self.labels.user_labeled_frames + ): + if len(frames) == 0: + return frames + video_user_labeled_frame_idxs = [ + lf.frame_idx for lf in user_labeled_frames if lf.video == video + ] return list(set(frames) - set(video_user_labeled_frame_idxs)) selection = dict() selection["frame"] = {self.video: [self.player.frame_idx]} - selection["clip"] = {self.video: list(range(*self.player.seekbar.getSelection()))} + selection["clip"] = { + self.video: list(range(*self.player.seekbar.getSelection())) + } selection["video"] = {self.video: list(range(self.video.num_frames))} selection["suggestions"] = { - video:remove_user_labeled(video, self.labels.get_video_suggestions(video)) - for video in self.labels.videos} + video: remove_user_labeled(video, self.labels.get_video_suggestions(video)) + for video in self.labels.videos + } selection["random"] = { video: remove_user_labeled(video, VideoFrameSuggestions.random(video=video)) - for video in self.labels.videos} + for video in self.labels.videos + } return selection - def _show_learning_window(self, mode): + def showLearningDialog(self, mode: str): + """Helper function to show active learning dialog in given mode. + + Args: + mode: A string representing mode for dialog, which could be: + * "active" + * "inference" + * "expert" + + Returns: + None. + """ from sleap.gui.active import ActiveLearningDialog + if "inference" in self.overlays: + QMessageBox( + text="In order to use this function you must first quit and " + "re-open SLEAP to release resources used by visualizing " + "model outputs." + ).exec_() + return + if self._child_windows.get(mode, None) is None: - self._child_windows[mode] = ActiveLearningDialog(self.filename, self.labels, mode) + self._child_windows[mode] = ActiveLearningDialog( + self.filename, self.labels, mode + ) self._child_windows[mode].learningFinished.connect(self.learningFinished) self._child_windows[mode].frame_selection = self._frames_for_prediction() self._child_windows[mode].open() def learningFinished(self): + """Called when active learning (or inference) finishes.""" # we ran active learning so update display/ui self.plotFrame() self.updateSeekbarMarks() - self.update_data_views() + self._update_data_views() self.changestack_push("new predictions") - def runLearningExpert(self): - self._show_learning_window("expert") - - def runInference(self): - self._show_learning_window("inference") - - def runActiveLearning(self): - self._show_learning_window("learning") - def visualizeOutputs(self): + """Gui for adding overlay with live visualization of predictions.""" filters = ["Model (*.json)", "HDF5 output (*.h5 *.hdf5)"] # Default to opening from models directory from project @@ -823,15 +1280,23 @@ def visualizeOutputs(self): models_dir = os.path.join(os.path.dirname(self.filename), "models/") # Show dialog - filename, selected_filter = QFileDialog.getOpenFileName(self, dir=models_dir, caption="Import model outputs...", filter=";;".join(filters)) - - if len(filename) == 0: return + filename, selected_filter = openFileDialog( + self, + dir=models_dir, + caption="Import model outputs...", + filter=";;".join(filters), + options=self._file_dialog_options, + ) + + if len(filename) == 0: + return if selected_filter == filters[0]: # Model as overlay datasource # This will show live inference results from sleap.gui.overlays.base import DataOverlay + overlay = DataOverlay.from_model(filename, self.video, player=self.player) self.overlays["inference"] = overlay @@ -845,41 +1310,45 @@ def visualizeOutputs(self): if show_confmaps: from sleap.gui.overlays.confmaps import ConfmapOverlay + confmap_overlay = ConfmapOverlay.from_h5(filename, player=self.player) - self.player.changedPlot.connect(lambda parent, idx: confmap_overlay.add_to_scene(None, idx)) + self.player.changedPlot.connect( + lambda parent, idx: confmap_overlay.add_to_scene(None, idx) + ) if show_pafs: from sleap.gui.overlays.pafs import PafOverlay + paf_overlay = PafOverlay.from_h5(filename, player=self.player) - self.player.changedPlot.connect(lambda parent, idx: paf_overlay.add_to_scene(None, idx)) + self.player.changedPlot.connect( + lambda parent, idx: paf_overlay.add_to_scene(None, idx) + ) self.plotFrame() def deletePredictions(self): + """Deletes all predicted instances in project.""" - predicted_instances = [(lf, inst) for lf in self.labels for inst in lf if type(inst) == PredictedInstance] - - resp = QMessageBox.critical(self, - "Removing predicted instances", - f"There are {len(predicted_instances)} predicted instances. " - "Are you sure you want to delete these?", - QMessageBox.Yes, QMessageBox.No) - - if resp == QMessageBox.No: return + predicted_instances = [ + (lf, inst) + for lf in self.labels + for inst in lf + if type(inst) == PredictedInstance + ] - for lf, inst in predicted_instances: - self.labels.remove_instance(lf, inst) - - self.plotFrame() - self.updateSeekbarMarks() - self.changestack_push("removed predictions") + self._delete_confirm(predicted_instances) def deleteClipPredictions(self): + """Deletes all instances within selected range of video frames.""" - predicted_instances = [(lf, inst) - for lf in self.labels.find(self.video, frame_idx = range(*self.player.seekbar.getSelection())) - for inst in lf - if type(inst) == PredictedInstance] + predicted_instances = [ + (lf, inst) + for lf in self.labels.find( + self.video, frame_idx=range(*self.player.seekbar.getSelection()) + ) + for inst in lf + if type(inst) == PredictedInstance + ] # If user selected an instance, then only delete for that track. selected_inst = self.player.view.getSelectionInstance() @@ -891,25 +1360,14 @@ def deleteClipPredictions(self): predicted_instances = [(self.labeled_frame, selected_inst)] else: # Filter by track - predicted_instances = list(filter(lambda x: x[1].track == track, predicted_instances)) - - resp = QMessageBox.critical(self, - "Removing predicted instances", - f"There are {len(predicted_instances)} predicted instances. " - "Are you sure you want to delete these?", - QMessageBox.Yes, QMessageBox.No) + predicted_instances = list( + filter(lambda x: x[1].track == track, predicted_instances) + ) - if resp == QMessageBox.No: return - - # Delete the instances - for lf, inst in predicted_instances: - self.labels.remove_instance(lf, inst) - - self.plotFrame() - self.updateSeekbarMarks() - self.changestack_push("removed predictions") + self._delete_confirm(predicted_instances) def deleteAreaPredictions(self): + """Gui for deleting instances within some rect on frame images.""" # Callback to delete after area has been selected def delete_area_callback(x0, y0, x1, y1): @@ -917,13 +1375,14 @@ def delete_area_callback(x0, y0, x1, y1): self.updateStatusMessage() # Make sure there was an area selected - if x0==x1 or y0==y1: return + if x0 == x1 or y0 == y1: + return min_corner = (x0, y0) max_corner = (x1, y1) def is_bounded(inst): - points_array = inst.points_array(invisible_as_nan=True) + points_array = inst.points_array valid_points = points_array[~np.isnan(points_array).any(axis=1)] is_gt_min = np.all(valid_points >= min_corner) @@ -931,63 +1390,79 @@ def is_bounded(inst): return is_gt_min and is_lt_max # Find all instances contained in selected area - predicted_instances = [(lf, inst) for lf in self.labels.find(self.video) - for inst in lf - if type(inst) == PredictedInstance - and is_bounded(inst)] + predicted_instances = [ + (lf, inst) + for lf in self.labels.find(self.video) + for inst in lf + if type(inst) == PredictedInstance and is_bounded(inst) + ] self._delete_confirm(predicted_instances) # Prompt the user to select area - self.updateStatusMessage(f"Please select the area from which to remove instances. This will be applied to all frames.") + self.updateStatusMessage( + f"Please select the area from which to remove instances. This will be applied to all frames." + ) self.player.onAreaSelection(delete_area_callback) def deleteLowScorePredictions(self): + """Gui for deleting instances below some score threshold.""" score_thresh, okay = QtWidgets.QInputDialog.getDouble( - self, - "Delete Instances with Low Score...", - "Score Below:", - 1, - 0, 100) + self, "Delete Instances with Low Score...", "Score Below:", 1, 0, 100 + ) if okay: # Find all instances contained in selected area - predicted_instances = [(lf, inst) for lf in self.labels.find(self.video) - for inst in lf - if type(inst) == PredictedInstance - and inst.score < score_thresh] + predicted_instances = [ + (lf, inst) + for lf in self.labels.find(self.video) + for inst in lf + if type(inst) == PredictedInstance and inst.score < score_thresh + ] self._delete_confirm(predicted_instances) def deleteFrameLimitPredictions(self): + """Gui for deleting instances beyond some number in each frame.""" count_thresh, okay = QtWidgets.QInputDialog.getInt( - self, - "Limit Instances in Frame...", - "Maximum instances in a frame:", - 3, - 1, 100) + self, + "Limit Instances in Frame...", + "Maximum instances in a frame:", + 3, + 1, + 100, + ) if okay: predicted_instances = [] # Find all instances contained in selected area for lf in self.labels.find(self.video): if len(lf.instances) > count_thresh: # Get all but the count_thresh many instances with the highest score - extra_instances = sorted(lf.instances, - key=operator.attrgetter('score') - )[:-count_thresh] + extra_instances = sorted( + lf.instances, key=operator.attrgetter("score") + )[:-count_thresh] predicted_instances.extend([(lf, inst) for inst in extra_instances]) self._delete_confirm(predicted_instances) def _delete_confirm(self, lf_inst_list): + """Helper function to confirm before deleting instances. - # Confirm that we want to delete - resp = QMessageBox.critical(self, - "Removing predicted instances", - f"There are {len(lf_inst_list)} predicted instances that would be deleted. " - "Are you sure you want to delete these?", - QMessageBox.Yes, QMessageBox.No) + Args: + lf_inst_list: A list of (labeled frame, instance) tuples. + """ - if resp == QMessageBox.No: return + # Confirm that we want to delete + resp = QMessageBox.critical( + self, + "Removing predicted instances", + f"There are {len(lf_inst_list)} predicted instances that would be deleted. " + "Are you sure you want to delete these?", + QMessageBox.Yes, + QMessageBox.No, + ) + + if resp == QMessageBox.No: + return # Delete the instances for lf, inst in lf_inst_list: @@ -999,135 +1474,189 @@ def _delete_confirm(self, lf_inst_list): self.changestack_push("removed predictions") def markNegativeAnchor(self): + """Allows user to add negative training sample anchor.""" + def click_callback(x, y): self.updateStatusMessage() self.labels.add_negative_anchor(self.video, self.player.frame_idx, (x, y)) self.changestack_push("add negative anchors") + self.plotFrame() # Prompt the user to select area self.updateStatusMessage(f"Please click where you want a negative sample...") self.player.onPointSelection(click_callback) + def clearFrameNegativeAnchors(self): + """Removes negative training sample anchors on current frame.""" + self.labels.remove_negative_anchors(self.video, self.player.frame_idx) + self.changestack_push("remove negative anchors") + self.plotFrame() + def importPredictions(self): + """Starts gui for importing another dataset into currently one.""" filters = ["HDF5 dataset (*.h5 *.hdf5)", "JSON labels (*.json *.json.zip)"] - filenames, selected_filter = QFileDialog.getOpenFileNames(self, dir=None, caption="Import labeled data...", filter=";;".join(filters)) - - if len(filenames) == 0: return + filenames, selected_filter = openFileDialogs( + self, + dir=None, + caption="Import labeled data...", + filter=";;".join(filters), + options=self._file_dialog_options, + ) + + if len(filenames) == 0: + return for filename in filenames: gui_video_callback = Labels.make_gui_video_callback( - search_paths=[os.path.dirname(filename)]) - - if filename.endswith((".h5", ".hdf5")): - new_labels = Labels.load_hdf5( - filename, - match_to=self.labels, - video_callback=gui_video_callback) + search_paths=[os.path.dirname(filename)] + ) - elif filename.endswith((".json", ".json.zip")): - new_labels = Labels.load_json( - filename, - match_to=self.labels, - video_callback=gui_video_callback) + new_labels = Labels.load_file(filename, video_callback=gui_video_callback) - self.labels.extend_from(new_labels) - - for vid in new_labels.videos: - print(f"Labels imported for {vid.filename}") - print(f" frames labeled: {len(new_labels.find(vid))}") + # Merging data is handled by MergeDialog + MergeDialog(base_labels=self.labels, new_labels=new_labels).exec_() # update display/ui self.plotFrame() self.updateSeekbarMarks() - self.update_data_views() + self._update_data_views() self.changestack_push("new predictions") - def doubleClickInstance(self, instance): + def doubleClickInstance(self, instance: Instance): + """ + Handles when the user has double-clicked an instance. + + If prediction, then copy to new user-instance. + If already user instance, then add any missing nodes (in case + skeleton has been changed after instance was created). + + Args: + instance: The :class:`Instance` that was double-clicked. + """ # When a predicted instance is double-clicked, add a new instance if hasattr(instance, "score"): - self.newInstance(copy_instance = instance) + self.newInstance(copy_instance=instance) # When a regular instance is double-clicked, add any missing points else: # the rect that's currently visibile in the window view - in_view_rect = self.player.view.mapToScene(self.player.view.rect()).boundingRect() + in_view_rect = self.player.view.mapToScene( + self.player.view.rect() + ).boundingRect() for node in self.skeleton.nodes: - if node not in instance.nodes or instance[node].isnan(): + if node.name not in instance.node_names or instance[node].isnan(): # pick random points within currently zoomed view - x = in_view_rect.x() + (in_view_rect.width() * 0.1) \ + x = ( + in_view_rect.x() + + (in_view_rect.width() * 0.1) + (np.random.rand() * in_view_rect.width() * 0.8) - y = in_view_rect.y() + (in_view_rect.height() * 0.1) \ + ) + y = ( + in_view_rect.y() + + (in_view_rect.height() * 0.1) + (np.random.rand() * in_view_rect.height() * 0.8) + ) # set point for node instance[node] = Point(x=x, y=y, visible=False) self.plotFrame() - def newInstance(self, copy_instance=None): + def newInstance(self, copy_instance: Optional[Instance] = None): + """ + Creates a new instance, copying node coordinates as appropriate. + + Args: + copy_instance: The :class:`Instance` (or + :class:`PredictedInstance`) which we want to copy. + """ if self.labeled_frame is None: return # FIXME: filter by skeleton type from_predicted = copy_instance - unused_predictions = self.labeled_frame.unused_predictions - from_prev_frame = False + if copy_instance is None: selected_idx = self.player.view.getSelection() if selected_idx is not None: # If the user has selected an instance, copy that one. copy_instance = self.labeled_frame.instances[selected_idx] from_predicted = copy_instance - elif len(unused_predictions): + + if copy_instance is None: + unused_predictions = self.labeled_frame.unused_predictions + if len(unused_predictions): # If there are predicted instances that don't correspond to an instance # in this frame, use the first predicted instance without matching instance. copy_instance = unused_predictions[0] from_predicted = copy_instance - else: - # Otherwise, if there are instances in previous frames, - # copy the points from one of those instances. - prev_idx = self.previousLabeledFrameIndex() - if prev_idx is not None: - prev_instances = self.labels.find(self.video, prev_idx, return_new=True)[0].instances - if len(prev_instances) > len(self.labeled_frame.instances): - # If more instances in previous frame than current, then use the - # first unmatched instance. - copy_instance = prev_instances[len(self.labeled_frame.instances)] - from_prev_frame = True - elif len(self.labeled_frame.instances): - # Otherwise, if there are already instances in current frame, - # copy the points from the last instance added to frame. - copy_instance = self.labeled_frame.instances[-1] - elif len(prev_instances): - # Otherwise use the last instance added to previous frame. - copy_instance = prev_instances[-1] - from_prev_frame = True + + if copy_instance is None: + # Otherwise, if there are instances in previous frames, + # copy the points from one of those instances. + prev_idx = self.previousLabeledFrameIndex() + + if prev_idx is not None: + prev_instances = self.labels.find( + self.video, prev_idx, return_new=True + )[0].instances + if len(prev_instances) > len(self.labeled_frame.instances): + # If more instances in previous frame than current, then use the + # first unmatched instance. + copy_instance = prev_instances[len(self.labeled_frame.instances)] + from_prev_frame = True + elif len(self.labeled_frame.instances): + # Otherwise, if there are already instances in current frame, + # copy the points from the last instance added to frame. + copy_instance = self.labeled_frame.instances[-1] + elif len(prev_instances): + # Otherwise use the last instance added to previous frame. + copy_instance = prev_instances[-1] + from_prev_frame = True + from_predicted = from_predicted if hasattr(from_predicted, "score") else None + + # Now create the new instance new_instance = Instance(skeleton=self.skeleton, from_predicted=from_predicted) - # the rect that's currently visibile in the window view - in_view_rect = self.player.view.mapToScene(self.player.view.rect()).boundingRect() + # Get the rect that's currently visibile in the window view + in_view_rect = self.player.view.mapToScene( + self.player.view.rect() + ).boundingRect() # go through each node in skeleton - for node in self.skeleton.nodes: + for node in self.skeleton.node_names: # if we're copying from a skeleton that has this node - if copy_instance is not None and node in copy_instance.nodes and not copy_instance[node].isnan(): + if ( + copy_instance is not None + and node in copy_instance + and not copy_instance[node].isnan() + ): # just copy x, y, and visible # we don't want to copy a PredictedPoint or score attribute new_instance[node] = Point( - x=copy_instance[node].x, - y=copy_instance[node].y, - visible=copy_instance[node].visible) + x=copy_instance[node].x, + y=copy_instance[node].y, + visible=copy_instance[node].visible, + ) else: # pick random points within currently zoomed view - x = in_view_rect.x() + (in_view_rect.width() * 0.1) \ + x = ( + in_view_rect.x() + + (in_view_rect.width() * 0.1) + (np.random.rand() * in_view_rect.width() * 0.8) - y = in_view_rect.y() + (in_view_rect.height() * 0.1) \ + ) + y = ( + in_view_rect.y() + + (in_view_rect.height() * 0.1) + (np.random.rand() * in_view_rect.height() * 0.8) + ) # mark the node as not "visible" if we're copying from a predicted instance without this node - is_visible = copy_instance is None or not hasattr(copy_instance, "score") + is_visible = copy_instance is None or not hasattr( + copy_instance, "score" + ) # set point for node new_instance[node] = Point(x=x, y=y, visible=is_visible) @@ -1149,8 +1678,10 @@ def newInstance(self, copy_instance=None): self.updateTrackMenu() def deleteSelectedInstance(self): + """Deletes currently selected instance.""" selected_inst = self.player.view.getSelectionInstance() - if selected_inst is None: return + if selected_inst is None: + return self.labels.remove_instance(self.labeled_frame, selected_inst) self.changestack_push("delete instance") @@ -1159,8 +1690,10 @@ def deleteSelectedInstance(self): self.updateSeekbarMarks() def deleteSelectedInstanceTrack(self): + """Deletes all instances from track of currently selected instance.""" selected_inst = self.player.view.getSelectionInstance() - if selected_inst is None: return + if selected_inst is None: + return # to do: range of frames? @@ -1180,9 +1713,10 @@ def deleteSelectedInstanceTrack(self): self.updateSeekbarMarks() def addTrack(self): - track_numbers_used = [int(track.name) - for track in self.labels.tracks - if track.name.isnumeric()] + """Creates new track and moves selected instance into this track.""" + track_numbers_used = [ + int(track.name) for track in self.labels.tracks if track.name.isnumeric() + ] next_number = max(track_numbers_used, default=0) + 1 new_track = Track(spawned_on=self.player.frame_idx, name=next_number) @@ -1197,9 +1731,11 @@ def addTrack(self): self.updateTrackMenu() self.updateSeekbarMarks() - def setInstanceTrack(self, new_track): + def setInstanceTrack(self, new_track: "Track"): + """Sets track for selected instance.""" vis_idx = self.player.view.getSelection() - if vis_idx is None: return + if vis_idx is None: + return selected_instance = self.labeled_frame.instances_to_show[vis_idx] idx = self.labeled_frame.index(selected_instance) @@ -1211,13 +1747,16 @@ def setInstanceTrack(self, new_track): if old_track is None: # Move anything already in the new track out of it new_track_instances = self.labels.find_track_instances( - video = self.video, - track = new_track, - frame_range = (self.player.frame_idx, self.player.frame_idx+1)) + video=self.video, + track=new_track, + frame_range=(self.player.frame_idx, self.player.frame_idx + 1), + ) for instance in new_track_instances: instance.track = None # Move selected instance into new track - self.labels.track_set_instance(self.labeled_frame, selected_instance, new_track) + self.labels.track_set_instance( + self.labeled_frame, selected_instance, new_track + ) # When the instance does already have a track, then we want to update # the track for a range of frames. @@ -1244,29 +1783,38 @@ def setInstanceTrack(self, new_track): self.player.view.selectInstance(idx) def transposeInstance(self): + """Transposes tracks for two instances. + + If there are only two instances, then this swaps tracks. + Otherwise, it allows user to select the instances for which we want + to swap tracks. + """ # We're currently identifying instances by numeric index, so it's # impossible to (e.g.) have a single instance which we identify # as the second instance in some other frame. # For the present, we can only "transpose" if there are multiple instances. - if len(self.labeled_frame.instances) < 2: return + if len(self.labeled_frame.instances) < 2: + return # If there are just two instances, transpose them. if len(self.labeled_frame.instances) == 2: - self._transpose_instances((0,1)) + self._transpose_instances((0, 1)) # If there are more than two, then we need the user to select the instances. else: - self.player.onSequenceSelect(seq_len = 2, - on_success = self._transpose_instances, - on_each = self._transpose_message, - on_failure = lambda x:self.updateStatusMessage() - ) - - def _transpose_message(self, instance_ids:list): + self.player.onSequenceSelect( + seq_len=2, + on_success=self._transpose_instances, + on_each=self._transpose_message, + on_failure=lambda x: self.updateStatusMessage(), + ) + + def _transpose_message(self, instance_ids: list): word = "next" if len(instance_ids) else "first" self.updateStatusMessage(f"Please select the {word} instance to transpose...") - def _transpose_instances(self, instance_ids:list): - if len(instance_ids) != 2: return + def _transpose_instances(self, instance_ids: list): + if len(instance_ids) != 2: + return idx_0 = instance_ids[0] idx_1 = instance_ids[1] @@ -1288,73 +1836,102 @@ def _transpose_instances(self, instance_ids:list): self.updateSeekbarMarks() def newProject(self): + """Create a new project in a new window.""" window = MainWindow() window.showMaximized() - def openProject(self, first_open=False): - filters = ["JSON labels (*.json *.json.zip)", "HDF5 dataset (*.h5 *.hdf5)", "Matlab dataset (*.mat)", "DeepLabCut csv (*.csv)"] - filename, selected_filter = QFileDialog.getOpenFileName(self, dir=None, caption="Import labeled data...", filter=";;".join(filters)) - - if len(filename) == 0: return + def openProject(self, first_open: bool = False): + """ + Allows use to select and then open a saved project. + + Args: + first_open: Whether this is the first window opened. If True, + then the new project is loaded into the current window + rather than a new application window. + + Returns: + None. + """ + filters = [ + "HDF5 dataset (*.h5 *.hdf5)", + "JSON labels (*.json *.json.zip)", + "Matlab dataset (*.mat)", + "DeepLabCut csv (*.csv)", + ] + + filename, selected_filter = openFileDialog( + self, + dir=None, + caption="Import labeled data...", + filter=";;".join(filters), + # options=self._file_dialog_options, + ) + + if len(filename) == 0: + return if OPEN_IN_NEW and not first_open: new_window = MainWindow() new_window.showMaximized() - new_window.importData(filename) + new_window.loadProject(filename) else: - self.importData(filename) + self.loadProject(filename) def saveProject(self): + """Show gui to save project (or save as if not yet saved).""" if self.filename is not None: - filename = self.filename - - if filename.endswith((".json", ".json.zip")): - compress = filename.endswith(".zip") - Labels.save_json(labels = self.labels, filename = filename, - compress = compress) - elif filename.endswith(".h5"): - Labels.save_hdf5(labels = self.labels, filename = filename) - - # Mark savepoint in change stack - self.changestack_savepoint() - # Redraw. Not sure why, but sometimes we need to do this. - self.plotFrame() + self._trySave(self.filename) else: # No filename (must be new project), so treat as "Save as" self.saveProjectAs() def saveProjectAs(self): - default_name = self.filename if self.filename is not None else "untitled.json" + """Show gui to save project as a new file.""" + default_name = self.filename if self.filename is not None else "untitled" p = PurePath(default_name) default_name = str(p.with_name(f"{p.stem} copy{p.suffix}")) - filters = ["JSON labels (*.json)", "Compressed JSON (*.zip)", "HDF5 dataset (*.h5)"] - filename, selected_filter = QFileDialog.getSaveFileName(self, - caption="Save As...", - dir=default_name, - filter=";;".join(filters)) - - if len(filename) == 0: return + filters = [ + "HDF5 dataset (*.h5)", + "JSON labels (*.json)", + "Compressed JSON (*.zip)", + ] + filename, selected_filter = saveFileDialog( + self, + caption="Save As...", + dir=default_name, + filter=";;".join(filters), + options=self._file_dialog_options, + ) + + if len(filename) == 0: + return - if filename.endswith((".json", ".zip")): - compress = filename.endswith(".zip") - Labels.save_json(labels = self.labels, filename = filename, compress = compress) - self.filename = filename - # Mark savepoint in change stack - self.changestack_savepoint() - # Redraw. Not sure why, but sometimes we need to do this. - self.plotFrame() - elif filename.endswith(".h5"): - Labels.save_hdf5(labels = self.labels, filename = filename) + if self._trySave(filename): + # If save was successful self.filename = filename + + def _trySave(self, filename): + """Helper function which attempts save and handles errors.""" + success = False + try: + Labels.save_file(labels=self.labels, filename=filename) + success = True # Mark savepoint in change stack self.changestack_savepoint() - # Redraw. Not sure why, but sometimes we need to do this. - self.plotFrame() - else: - QMessageBox(text=f"File not saved. Try saving as json.").exec_() + + except Exception as e: + message = f"An error occured when attempting to save:\n {e}\n\n" + message += "Try saving your project with a different filename or in a different format." + QtWidgets.QMessageBox(text=message).exec_() + + # Redraw. Not sure why, but sometimes we need to do this. + self.plotFrame() + + return success def closeEvent(self, event): + """Closes application window, prompting for saving as needed.""" if not self.changestack_has_changes(): # No unsaved changes, so accept event (close) event.accept() @@ -1362,7 +1939,9 @@ def closeEvent(self, event): msgBox = QMessageBox() msgBox.setText("Do you want to save the changes to this project?") msgBox.setInformativeText("If you don't save, your changes will be lost.") - msgBox.setStandardButtons(QMessageBox.Save | QMessageBox.Discard | QMessageBox.Cancel) + msgBox.setStandardButtons( + QMessageBox.Save | QMessageBox.Discard | QMessageBox.Cancel + ) msgBox.setDefaultButton(QMessageBox.Save) ret_val = msgBox.exec_() @@ -1380,150 +1959,192 @@ def closeEvent(self, event): event.accept() def nextVideo(self): - new_idx = self.video_idx+1 + """Activates next video in project.""" + new_idx = self.video_idx + 1 new_idx = 0 if new_idx >= len(self.labels.videos) else new_idx self.loadVideo(self.labels.videos[new_idx], new_idx) def previousVideo(self): - new_idx = self.video_idx-1 - new_idx = len(self.labels.videos)-1 if new_idx < 0 else new_idx + """Activates previous video in project.""" + new_idx = self.video_idx - 1 + new_idx = len(self.labels.videos) - 1 if new_idx < 0 else new_idx self.loadVideo(self.labels.videos[new_idx], new_idx) def gotoFrame(self): + """Shows gui to go to frame by number.""" frame_number, okay = QtWidgets.QInputDialog.getInt( - self, - "Go To Frame...", - "Frame Number:", - self.player.frame_idx+1, - 1, self.video.frames) + self, + "Go To Frame...", + "Frame Number:", + self.player.frame_idx + 1, + 1, + self.video.frames, + ) if okay: - self.plotFrame(frame_number-1) - - def markFrame(self): - self.mark_idx = self.player.frame_idx - - def goMarkedFrame(self): - self.plotFrame(self.mark_idx) + self.plotFrame(frame_number - 1) def exportLabeledClip(self): + """Shows gui for exporting clip with visual annotations.""" from sleap.io.visuals import save_labeled_video + if self.player.seekbar.hasSelection(): fps, okay = QtWidgets.QInputDialog.getInt( - self, - "Frames per second", - "Frames per second:", - getattr(self.video, "fps", 30), - 1, 300) - if not okay: return + self, + "Frames per second", + "Frames per second:", + getattr(self.video, "fps", 30), + 1, + 300, + ) + if not okay: + return - filename, _ = QFileDialog.getSaveFileName(self, caption="Save Video As...", dir=self.filename + ".avi", filter="AVI Video (*.avi)") + filename, _ = saveFileDialog( + self, + caption="Save Video As...", + dir=self.filename + ".avi", + filter="AVI Video (*.avi)", + options=self._file_dialog_options, + ) - if len(filename) == 0: return + if len(filename) == 0: + return save_labeled_video( - labels=self.labels, - video=self.video, - filename=filename, - frames=list(range(*self.player.seekbar.getSelection())), - fps=fps, - gui_progress=True - ) + labels=self.labels, + video=self.video, + filename=filename, + frames=list(range(*self.player.seekbar.getSelection())), + fps=fps, + gui_progress=True, + ) def exportLabeledFrames(self): - filename, _ = QFileDialog.getSaveFileName(self, caption="Save Labeled Frames As...", dir=self.filename) - if len(filename) == 0: return - Labels.save_json(self.labels, filename, save_frame_data=True) - - def previousLabeledFrameIndex(self): - cur_idx = self.player.frame_idx - frames = self.labels.frames(self.video, from_frame_idx=cur_idx, reverse=True) - - try: - next_idx = next(frames).frame_idx - except: + """Gui for exporting the training dataset of labels/frame images.""" + filters = ["HDF5 dataset (*.h5)", "Compressed JSON dataset (*.json *.json.zip)"] + filename, _ = saveFileDialog( + self, + caption="Save Labeled Frames As...", + dir=self.filename + ".h5", + filters=";;".join(filters), + options=self._file_dialog_options, + ) + if len(filename) == 0: return - def previousLabeledFrame(self): - prev_idx = self.previousLabeledFrameIndex() - if prev_idx is not None: - self.plotFrame(prev_idx) + Labels.save_file( + self.labels, filename, default_suffix="h5", save_frame_data=True + ) - def nextLabeledFrame(self): - cur_idx = self.player.frame_idx + def _plot_if_next(self, frame_iterator: Iterator) -> bool: + """Plots next frame (if there is one) from iterator. - frames = self.labels.frames(self.video, from_frame_idx=cur_idx) + Arguments: + frame_iterator: The iterator from which we'll try to get next + :class:`LabeledFrame`. + Returns: + True if we went to next frame. + """ try: - next_idx = next(frames).frame_idx - except: - return + next_lf = next(frame_iterator) + except StopIteration: + return False - self.plotFrame(next_idx) + self.plotFrame(next_lf.frame_idx) + return True - def nextUserLabeledFrame(self): - cur_idx = self.player.frame_idx + def previousLabeledFrame(self): + """Goes to labeled frame prior to current frame.""" + frames = self.labels.frames( + self.video, from_frame_idx=self.player.frame_idx, reverse=True + ) + self._plot_if_next(frames) + + def nextLabeledFrame(self): + """Goes to labeled frame after current frame.""" + frames = self.labels.frames(self.video, from_frame_idx=self.player.frame_idx) + self._plot_if_next(frames) - frames = self.labels.frames(self.video, from_frame_idx=cur_idx) + def nextUserLabeledFrame(self): + """Goes to next labeled frame with user instances.""" + frames = self.labels.frames(self.video, from_frame_idx=self.player.frame_idx) # Filter to frames with user instances frames = filter(lambda lf: lf.has_user_instances, frames) - - try: - next_idx = next(frames).frame_idx - except: - return - - self.plotFrame(next_idx) + self._plot_if_next(frames) def nextSuggestedFrame(self, seek_direction=1): - next_video, next_frame = self.labels.get_next_suggestion(self.video, self.player.frame_idx, seek_direction) + """Goes to next (or previous) suggested frame.""" + next_video, next_frame = self.labels.get_next_suggestion( + self.video, self.player.frame_idx, seek_direction + ) if next_video is not None: self.gotoVideoAndFrame(next_video, next_frame) if next_frame is not None: - selection_idx = self.labels.get_suggestions().index((next_video, next_frame)) + selection_idx = self.labels.get_suggestions().index( + (next_video, next_frame) + ) self.suggestionsTable.selectRow(selection_idx) def nextTrackFrame(self): + """Goes to next frame on which a track starts.""" cur_idx = self.player.frame_idx - video_tracks = {inst.track for lf in self.labels.find(self.video) for inst in lf if inst.track is not None} - next_idx = min([track.spawned_on for track in video_tracks if track.spawned_on > cur_idx], default=-1) + track_ranges = self.labels.get_track_occupany(self.video) + next_idx = min( + [ + track_range.start + for track_range in track_ranges.values() + if track_range.start is not None and track_range.start > cur_idx + ], + default=-1, + ) if next_idx > -1: self.plotFrame(next_idx) - def gotoVideoAndFrame(self, video, frame_idx): + def gotoVideoAndFrame(self, video: Video, frame_idx: int): + """Activates video and goes to frame.""" if video != self.video: # switch to the other video self.loadVideo(video) self.plotFrame(frame_idx) def toggleLabels(self): + """Toggles whether skeleton node labels are shown in video overlay.""" self._show_labels = not self._show_labels self._menu_actions["show labels"].setChecked(self._show_labels) self.player.showLabels(self._show_labels) def toggleEdges(self): + """Toggles whether skeleton edges are shown in video overlay.""" self._show_edges = not self._show_edges self._menu_actions["show edges"].setChecked(self._show_edges) self.player.showEdges(self._show_edges) def toggleTrails(self): + """Toggles whether track trails are shown in video overlay.""" self.overlays["trails"].show = not self.overlays["trails"].show self._menu_actions["show trails"].setChecked(self.overlays["trails"].show) self.plotFrame() - def setTrailLength(self, trail_length): + def setTrailLength(self, trail_length: int): + """Sets length of track trails to show in video overlay.""" self.overlays["trails"].trail_length = trail_length self._menu_check_single(self.trailLengthMenu, trail_length) - if self.video is not None: self.plotFrame() + if self.video is not None: + self.plotFrame() - def setPalette(self, palette): + def setPalette(self, palette: str): + """Sets color palette used for track colors.""" self._color_manager.set_palette(palette) self._menu_check_single(self.paletteMenu, palette) - if self.video is not None: self.plotFrame() + if self.video is not None: + self.plotFrame() self.updateSeekbarMarks() def _menu_check_single(self, menu, item_text): + """Helper method to select exactly one submenu item.""" for menu_item in menu.children(): if menu_item.text() == str(item_text): menu_item.setChecked(True) @@ -1531,70 +2152,76 @@ def _menu_check_single(self, menu, item_text): menu_item.setChecked(False) def toggleColorPredicted(self): - self.overlays["instance"].color_predicted = not self.overlays["instance"].color_predicted - self._menu_actions["color predicted"].setChecked(self.overlays["instance"].color_predicted) + """Toggles whether predicted instances are shown in track colors.""" + val = self.overlays["instance"].color_predicted + self.overlays["instance"].color_predicted = not val + self._menu_actions["color predicted"].setChecked( + self.overlays["instance"].color_predicted + ) self.plotFrame() def toggleAutoZoom(self): + """Toggles whether to zoom viewer to fit labeled instances.""" self._auto_zoom = not self._auto_zoom self._menu_actions["fit"].setChecked(self._auto_zoom) if not self._auto_zoom: self.player.view.clearZoom() self.plotFrame() - def openDocumentation(self): - pass def openKeyRef(self): - pass - def openAbout(self): - pass + """Shows gui for viewing/modifying keyboard shortucts.""" + ShortcutDialog().exec_() - def newFrame(self, player, frame_idx, selected_idx): - """Called each time a new frame is drawn.""" - - # Store the current LabeledFrame (or make new, empty object) - self.labeled_frame = self.labels.find(self.video, frame_idx, return_new=True)[0] - # Show instances, etc, for this frame - for overlay in self.overlays.values(): - overlay.add_to_scene(self.video, frame_idx) +def openFileDialog(*args, **kwargs): + """Wrapper for openFileDialog. - # Select instance if there was already selection - if selected_idx > -1: - player.view.selectInstance(selected_idx) + Passes along everything except empty "options" arg. + """ + if "options" in kwargs and not kwargs["options"]: + del kwargs["options"] + return QFileDialog.getOpenFileName(*args, **kwargs) - # Update related displays - self.updateStatusMessage() - self.update_data_views() - # Trigger event after the overlays have been added - player.view.updatedViewer.emit() +def saveFileDialog(*args, **kwargs): + """Wrapper for saveFileDialog. - def updateStatusMessage(self, message = None): - if message is None: - message = f"Frame: {self.player.frame_idx+1}/{len(self.video)}" - if self.player.seekbar.hasSelection(): - start, end = self.player.seekbar.getSelection() - message += f" (selection: {start}-{end})" + Passes along everything except empty "options" arg. + """ + if "options" in kwargs and not kwargs["options"]: + del kwargs["options"] + return QFileDialog.getSaveFileName(*args, **kwargs) - self.statusBar().showMessage(message) def main(*args, **kwargs): + """Starts new instance of app.""" app = QApplication([]) - app.setApplicationName("sLEAP Label") + app.setApplicationName("SLEAP Label") window = MainWindow(*args, **kwargs) window.showMaximized() - if "import_data" not in kwargs: + if not kwargs.get("labels_path", None): window.openProject(first_open=True) app.exec_() -if __name__ == "__main__": - - kwargs = dict() - if len(sys.argv) > 1: - kwargs["import_data"] = sys.argv[1] - main(**kwargs) +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "labels_path", help="Path to labels file", type=str, default=None, nargs="?" + ) + parser.add_argument( + "--nonnative", + help="Don't use native file dialogs", + action="store_const", + const=True, + default=False, + ) + + args = parser.parse_args() + + main(**vars(args)) diff --git a/sleap/gui/dataviews.py b/sleap/gui/dataviews.py index 8cbcdf52c..b9dd9b6d0 100644 --- a/sleap/gui/dataviews.py +++ b/sleap/gui/dataviews.py @@ -1,89 +1,107 @@ -from PySide2 import QtCore -from PySide2.QtCore import Qt +""" +Data table widgets and view models used in GUI app. +""" -from PySide2.QtGui import QKeyEvent, QColor - -from PySide2.QtWidgets import QApplication, QMainWindow, QWidget, QDockWidget -from PySide2.QtWidgets import QVBoxLayout, QHBoxLayout, QGroupBox, QFormLayout -from PySide2.QtWidgets import QLabel, QPushButton, QLineEdit, QSpinBox, QDoubleSpinBox, QComboBox, QCheckBox -from PySide2.QtWidgets import QTableWidget, QTableView, QTableWidgetItem, QAbstractItemView -from PySide2.QtWidgets import QTreeView, QTreeWidget, QTreeWidgetItem -from PySide2.QtWidgets import QMenu, QAction -from PySide2.QtWidgets import QFileDialog, QMessageBox +from PySide2 import QtCore, QtWidgets, QtGui import os -import numpy as np -import pandas as pd +from operator import itemgetter -from typing import Callable +from typing import Callable, List, Optional from sleap.gui.overlays.tracks import TrackColorManager -from sleap.io.video import Video from sleap.io.dataset import Labels -from sleap.instance import LabeledFrame -from sleap.skeleton import Skeleton, Node +from sleap.instance import LabeledFrame, Instance +from sleap.skeleton import Skeleton + +class VideosTable(QtWidgets.QTableView): + """Table view widget for listing videos in dataset.""" -class VideosTable(QTableView): - """Table view widget backed by a custom data model for displaying - lists of Video instances. """ def __init__(self, videos: list = []): super(VideosTable, self).__init__() - self.setModel(VideosTableModel(videos)) - self.setSelectionBehavior(QAbstractItemView.SelectRows) - self.setSelectionMode(QAbstractItemView.SingleSelection) -class VideosTableModel(QtCore.QAbstractTableModel): - _props = ["filename", "frames", "height", "width", "channels",] + props = ("filename", "frames", "height", "width", "channels") + model = GenericTableModel(props, videos, useCache=True) + + self.setModel(model) + + self.setSelectionBehavior(QtWidgets.QAbstractItemView.SelectRows) + self.setSelectionMode(QtWidgets.QAbstractItemView.SingleSelection) + + +class GenericTableModel(QtCore.QAbstractTableModel): + """Generic table model to show a list of properties for some items. - def __init__(self, videos: list): - super(VideosTableModel, self).__init__() - self._videos = videos + Args: + propList: The list of property names (table columns). + itemList: The list of items with said properties (rows). + useCache: Whether to build cache of property values for all items. + """ + + def __init__( + self, + propList: List[str], + itemList: Optional[list] = None, + useCache: bool = False, + ): + super(GenericTableModel, self).__init__() + self._use_cache = useCache + self._props = propList + + if itemList is not None: + self.items = itemList + else: + self._data = [] @property - def videos(self): - return self._videos + def items(self): + """Gets or sets list of items to show in table.""" + return self._data - @videos.setter - def videos(self, val): + @items.setter + def items(self, val): self.beginResetModel() - self._videos = val + if self._use_cache: + self._data = [] + for item in val: + item_data = {key: getattr(item, key) for key in self._props} + self._data.append(item_data) + else: + self._data = val self.endResetModel() - def data(self, index: QtCore.QModelIndex, role=Qt.DisplayRole): - if role == Qt.DisplayRole and index.isValid(): + def data(self, index: QtCore.QModelIndex, role=QtCore.Qt.DisplayRole): + """Overrides Qt method, returns data to show in table.""" + if role == QtCore.Qt.DisplayRole and index.isValid(): idx = index.row() - prop = self._props[index.column()] + key = self._props[index.column()] - if len(self.videos) > (idx - 1): - video = self.videos[idx] - - if prop == "filename": - # show parent dir + name - parent_dir = os.path.split(os.path.dirname(video.filename))[-1] - file_name = os.path.basename(video.filename) - trunc_name = os.path.join(parent_dir, file_name) - return trunc_name - elif prop == "frames": - return video.frames - elif prop == "height": - return video.height - elif prop == "width": - return video.width - elif prop == "channels": - return video.channels + if idx < self.rowCount(): + item = self.items[idx] + + if isinstance(item, dict) and key in item: + return item[key] + + if hasattr(item, key): + return getattr(item, key) return None - def rowCount(self, parent): - return len(self.videos) + def rowCount(self, parent=None): + """Overrides Qt method, returns number of rows (items).""" + return len(self._data) - def columnCount(self, parent): - return len(VideosTableModel._props) + def columnCount(self, parent=None): + """Overrides Qt method, returns number of columns (attributes).""" + return len(self._props) - def headerData(self, section, orientation: QtCore.Qt.Orientation, role=Qt.DisplayRole): - if role == Qt.DisplayRole: + def headerData( + self, section, orientation: QtCore.Qt.Orientation, role=QtCore.Qt.DisplayRole + ): + """Overrides Qt method, returns column (attribute) names.""" + if role == QtCore.Qt.DisplayRole: if orientation == QtCore.Qt.Horizontal: return self._props[section] elif orientation == QtCore.Qt.Vertical: @@ -91,20 +109,39 @@ def headerData(self, section, orientation: QtCore.Qt.Orientation, role=Qt.Displa return None + def sort(self, column_idx: int, order: QtCore.Qt.SortOrder): + """Sorts table by given column and order.""" + prop = self._props[column_idx] + + sort_function = itemgetter(prop) + if prop in ("video", "frame"): + if "video" in self._props and "frame" in self._props: + sort_function = itemgetter("video", "frame") + + reverse = order == QtCore.Qt.SortOrder.DescendingOrder + + self.beginResetModel() + self._data.sort(key=sort_function, reverse=reverse) + self.endResetModel() + def flags(self, index: QtCore.QModelIndex): - return Qt.ItemIsEnabled | Qt.ItemIsSelectable + """Overrides Qt method, returns whether item is selectable etc.""" + return QtCore.Qt.ItemIsEnabled | QtCore.Qt.ItemIsSelectable + +class SkeletonNodesTable(QtWidgets.QTableView): + """Table view widget for displaying and editing Skeleton nodes. """ -class SkeletonNodesTable(QTableView): - """Table view widget backed by a custom data model for displaying and - editing Skeleton nodes. """ def __init__(self, skeleton: Skeleton): super(SkeletonNodesTable, self).__init__() self.setModel(SkeletonNodesTableModel(skeleton)) - self.setSelectionBehavior(QAbstractItemView.SelectRows) - self.setSelectionMode(QAbstractItemView.SingleSelection) + self.setSelectionBehavior(QtWidgets.QAbstractItemView.SelectRows) + self.setSelectionMode(QtWidgets.QAbstractItemView.SingleSelection) + class SkeletonNodesTableModel(QtCore.QAbstractTableModel): + """Table model for skeleton nodes.""" + _props = ["name", "symmetry"] def __init__(self, skeleton: Skeleton): @@ -113,6 +150,7 @@ def __init__(self, skeleton: Skeleton): @property def skeleton(self): + """Gets or sets current skeleton.""" return self._skeleton @skeleton.setter @@ -121,11 +159,12 @@ def skeleton(self, val): self._skeleton = val self.endResetModel() - def data(self, index: QtCore.QModelIndex, role=Qt.DisplayRole): - if role == Qt.DisplayRole and index.isValid(): + def data(self, index: QtCore.QModelIndex, role=QtCore.Qt.DisplayRole): + """Overrides Qt method, returns data to show in table.""" + if role == QtCore.Qt.DisplayRole and index.isValid(): node_idx = index.row() prop = self._props[index.column()] - node = self.skeleton.nodes[node_idx] # FIXME? can we assume order is stable? + node = self.skeleton.nodes[node_idx] node_name = node.name if prop == "name": @@ -136,13 +175,18 @@ def data(self, index: QtCore.QModelIndex, role=Qt.DisplayRole): return None def rowCount(self, parent): + """Overrides Qt method, returns number of rows.""" return len(self.skeleton.nodes) def columnCount(self, parent): + """Overrides Qt method, returns number of columns.""" return len(SkeletonNodesTableModel._props) - def headerData(self, section, orientation: QtCore.Qt.Orientation, role=Qt.DisplayRole): - if role == Qt.DisplayRole: + def headerData( + self, section, orientation: QtCore.Qt.Orientation, role=QtCore.Qt.DisplayRole + ): + """Overrides Qt method, returns column names.""" + if role == QtCore.Qt.DisplayRole: if orientation == QtCore.Qt.Horizontal: return self._props[section] elif orientation == QtCore.Qt.Vertical: @@ -150,22 +194,24 @@ def headerData(self, section, orientation: QtCore.Qt.Orientation, role=Qt.Displa return None - def setData(self, index: QtCore.QModelIndex, value: str, role=Qt.EditRole): - if role == Qt.EditRole: + def setData(self, index: QtCore.QModelIndex, value: str, role=QtCore.Qt.EditRole): + """Overrides Qt method, updates skeleton with new data from user.""" + if role == QtCore.Qt.EditRole: node_idx = index.row() prop = self._props[index.column()] node_name = self.skeleton.nodes[node_idx].name try: if prop == "name": - if len(value) > 0: + # Change node name (unless empty string) + if value: self._skeleton.relabel_node(node_name, value) - # else: - # self._skeleton.delete_node(node_name) elif prop == "symmetry": - if len(value) > 0: + if value: self._skeleton.add_symmetry(node_name, value) else: - self._skeleton.delete_symmetry(node_name, self._skeleton.get_symmetry(node_name)) + # Value was cleared by user, so delete symmetry + symmetric_to = self._skeleton.get_symmetry(node_name) + self._skeleton.delete_symmetry(node_name, symmetric_to) # send signal that data has changed self.dataChanged.emit(index, index) @@ -178,90 +224,89 @@ def setData(self, index: QtCore.QModelIndex, value: str, role=Qt.EditRole): return False def flags(self, index: QtCore.QModelIndex): - return Qt.ItemIsEnabled | Qt.ItemIsSelectable | Qt.ItemIsEditable + """Overrides Qt method, returns flags (editable etc).""" + return ( + QtCore.Qt.ItemIsEnabled + | QtCore.Qt.ItemIsSelectable + | QtCore.Qt.ItemIsEditable + ) + +class SkeletonEdgesTable(QtWidgets.QTableView): + """Table view widget for skeleton edges.""" -class SkeletonEdgesTable(QTableView): - """Table view widget backed by a custom data model for displaying and - editing Skeleton edges. """ def __init__(self, skeleton: Skeleton): super(SkeletonEdgesTable, self).__init__() self.setModel(SkeletonEdgesTableModel(skeleton)) - self.setSelectionBehavior(QAbstractItemView.SelectRows) - self.setSelectionMode(QAbstractItemView.SingleSelection) -class SkeletonEdgesTableModel(QtCore.QAbstractTableModel): - _props = ["source", "destination"] + self.setSelectionBehavior(QtWidgets.QAbstractItemView.SelectRows) + self.setSelectionMode(QtWidgets.QAbstractItemView.SingleSelection) + + +class SkeletonEdgesTableModel(GenericTableModel): + """Table model for skeleton edges. + + Args: + skeleton: The skeleton to show in table. + """ def __init__(self, skeleton: Skeleton): - super(SkeletonEdgesTableModel, self).__init__() - self._skeleton = skeleton + props = ("source", "destination") + super(SkeletonEdgesTableModel, self).__init__(props) + self.skeleton = skeleton @property def skeleton(self): + """Gets or sets current skeleton.""" return self._skeleton @skeleton.setter def skeleton(self, val): - self.beginResetModel() self._skeleton = val - self.endResetModel() - - def data(self, index: QtCore.QModelIndex, role=Qt.DisplayRole): - if role == Qt.DisplayRole and index.isValid(): - idx = index.row() - prop = self._props[index.column()] - edge = self.skeleton.edges[idx] - - if prop == "source": - return edge[0].name - elif prop == "destination": - return edge[1].name - - return None - - def rowCount(self, parent): - return len(self.skeleton.edges) - - def columnCount(self, parent): - return len(SkeletonNodesTableModel._props) - - def headerData(self, section, orientation: QtCore.Qt.Orientation, role=Qt.DisplayRole): - if role == Qt.DisplayRole: - if orientation == QtCore.Qt.Horizontal: - return self._props[section] - elif orientation == QtCore.Qt.Vertical: - return section - - return None - - def flags(self, index: QtCore.QModelIndex): - return Qt.ItemIsEnabled | Qt.ItemIsSelectable + items = [ + dict(source=edge[0].name, destination=edge[1].name) + for edge in self._skeleton.edges + ] + self.items = items +class LabeledFrameTable(QtWidgets.QTableView): + """Table view widget for listing instances in labeled frame.""" - -class LabeledFrameTable(QTableView): - """Table view widget backed by a custom data model for displaying - lists of Video instances. """ - - selectionChangedSignal = QtCore.Signal(int) + selectionChangedSignal = QtCore.Signal(Instance) def __init__(self, labeled_frame: LabeledFrame = None, labels: Labels = None): super(LabeledFrameTable, self).__init__() self.setModel(LabeledFrameTableModel(labeled_frame, labels)) - self.setSelectionBehavior(QAbstractItemView.SelectRows) - self.setSelectionMode(QAbstractItemView.SingleSelection) + self.setSelectionBehavior(QtWidgets.QAbstractItemView.SelectRows) + self.setSelectionMode(QtWidgets.QAbstractItemView.SingleSelection) def selectionChanged(self, new, old): + """Custom event handler, emits selectionChangedSignal signal.""" super(LabeledFrameTable, self).selectionChanged(new, old) - row_idx = -1 + + instance = None if len(new.indexes()): row_idx = new.indexes()[0].row() - self.selectionChangedSignal.emit(row_idx) + try: + instance = self.model().labeled_frame.instances_to_show[row_idx] + except: + # Usually means that there's no labeled_frame + pass + + self.selectionChangedSignal.emit(instance) class LabeledFrameTableModel(QtCore.QAbstractTableModel): + """Table model for listing instances in labeled frame. + + Allows editing track names. + + Args: + labeled_frame: `LabeledFrame` to show + labels: `Labels` datasource + """ + _props = ("points", "track", "score", "skeleton") def __init__(self, labeled_frame: LabeledFrame, labels: Labels): @@ -271,6 +316,7 @@ def __init__(self, labeled_frame: LabeledFrame, labels: Labels): @property def labeled_frame(self): + """Gets or sets current labeled frame.""" return self._labeled_frame @labeled_frame.setter @@ -281,6 +327,7 @@ def labeled_frame(self, val): @property def labels(self): + """Gets or sets current labels dataset object.""" return self._labels @labels.setter @@ -290,13 +337,15 @@ def labels(self, val): @property def color_manager(self): + """Gets or sets object for determining track colors.""" return self._color_manager @color_manager.setter def color_manager(self, val): self._color_manager = val - def data(self, index: QtCore.QModelIndex, role=Qt.DisplayRole): + def data(self, index: QtCore.QModelIndex, role=QtCore.Qt.DisplayRole): + """Overrides Qt method, returns data to show in table.""" if index.isValid(): idx = index.row() prop = self._props[index.column()] @@ -305,7 +354,7 @@ def data(self, index: QtCore.QModelIndex, role=Qt.DisplayRole): instance = self.labeled_frame.instances_to_show[idx] # Cell value - if role == Qt.DisplayRole: + if role == QtCore.Qt.DisplayRole: if prop == "points": return f"{len(instance.nodes)}/{len(instance.skeleton.nodes)}" elif prop == "track" and instance.track is not None: @@ -319,20 +368,31 @@ def data(self, index: QtCore.QModelIndex, role=Qt.DisplayRole): return "" # Cell color - elif role == Qt.ForegroundRole: + elif role == QtCore.Qt.ForegroundRole: if prop == "track" and instance.track is not None: - return QColor(*self.color_manager.get_color(instance.track)) + return QtGui.QColor( + *self.color_manager.get_color(instance.track) + ) return None def rowCount(self, parent): - return len(self.labeled_frame.instances_to_show) if self.labeled_frame is not None else 0 + """Overrides Qt method, returns number of rows.""" + return ( + len(self.labeled_frame.instances_to_show) + if self.labeled_frame is not None + else 0 + ) def columnCount(self, parent): + """Overrides Qt method, returns number of columns.""" return len(LabeledFrameTableModel._props) - def headerData(self, section, orientation: QtCore.Qt.Orientation, role=Qt.DisplayRole): - if role == Qt.DisplayRole: + def headerData( + self, section, orientation: QtCore.Qt.Orientation, role=QtCore.Qt.DisplayRole + ): + """Overrides Qt method, returns column names.""" + if role == QtCore.Qt.DisplayRole: if orientation == QtCore.Qt.Horizontal: return self._props[section] elif orientation == QtCore.Qt.Vertical: @@ -340,8 +400,11 @@ def headerData(self, section, orientation: QtCore.Qt.Orientation, role=Qt.Displa return None - def setData(self, index: QtCore.QModelIndex, value: str, role=Qt.EditRole): - if role == Qt.EditRole: + def setData(self, index: QtCore.QModelIndex, value: str, role=QtCore.Qt.EditRole): + """ + Overrides Qt method, sets data in labeled frame from user changes. + """ + if role == QtCore.Qt.EditRole: idx = index.row() prop = self._props[index.column()] instance = self.labeled_frame.instances_to_show[idx] @@ -356,18 +419,30 @@ def setData(self, index: QtCore.QModelIndex, value: str, role=Qt.EditRole): return False def flags(self, index: QtCore.QModelIndex): - f = Qt.ItemIsEnabled | Qt.ItemIsSelectable + """Overrides Qt method, returns flags (editable etc).""" + f = QtCore.Qt.ItemIsEnabled | QtCore.Qt.ItemIsSelectable if index.isValid(): idx = index.row() if idx < len(self.labeled_frame.instances_to_show): instance = self.labeled_frame.instances_to_show[idx] prop = self._props[index.column()] if prop == "track" and instance.track is not None: - f |= Qt.ItemIsEditable + f |= QtCore.Qt.ItemIsEditable return f class SkeletonNodeModel(QtCore.QStringListModel): + """ + String list model for source/destination nodes of edges. + + Args: + skeleton: The skeleton for which to list nodes. + src_node: If given, then we assume that this model is being used for + edge destination node. Otherwise, we assume that this model is + being used for an edge source node. + If given, then this should be function that will return the + selected edge source node. + """ def __init__(self, skeleton: Skeleton, src_node: Callable = None): super(SkeletonNodeModel, self).__init__() @@ -376,6 +451,7 @@ def __init__(self, skeleton: Skeleton, src_node: Callable = None): @property def skeleton(self): + """Gets or sets current skeleton.""" return self._skeleton @skeleton.setter @@ -410,41 +486,50 @@ def is_valid_dst(node): return valid_dst_nodes - def data(self, index: QtCore.QModelIndex, role=Qt.DisplayRole): - if role == Qt.DisplayRole and index.isValid(): + def data(self, index: QtCore.QModelIndex, role=QtCore.Qt.DisplayRole): + """Overrides Qt method, returns data for given row.""" + if role == QtCore.Qt.DisplayRole and index.isValid(): idx = index.row() return self._node_list[idx] return None def rowCount(self, parent): + """Overrides Qt method, returns number of rows.""" return len(self._node_list) def columnCount(self, parent): + """Overrides Qt method, returns number of columns (1).""" return 1 def flags(self, index: QtCore.QModelIndex): - return Qt.ItemIsEnabled | Qt.ItemIsSelectable + """Overrides Qt method, returns flags (editable etc).""" + return QtCore.Qt.ItemIsEnabled | QtCore.Qt.ItemIsSelectable + +class SuggestionsTable(QtWidgets.QTableView): + """Table view widget for showing frame suggestions.""" -class SuggestionsTable(QTableView): - """Table view widget backed by a custom data model for displaying - lists of Video instances. """ def __init__(self, labels): super(SuggestionsTable, self).__init__() self.setModel(SuggestionsTableModel(labels)) - self.setSelectionBehavior(QAbstractItemView.SelectRows) - self.setSelectionMode(QAbstractItemView.SingleSelection) + self.setSelectionBehavior(QtWidgets.QAbstractItemView.SelectRows) + self.setSelectionMode(QtWidgets.QAbstractItemView.SingleSelection) + self.setSortingEnabled(True) -class SuggestionsTableModel(QtCore.QAbstractTableModel): - _props = ["video", "frame", "labeled",] + +class SuggestionsTableModel(GenericTableModel): + """Table model for showing frame suggestions.""" def __init__(self, labels): - super(SuggestionsTableModel, self).__init__() + props = ("video", "frame", "labeled", "mean score") + + super(SuggestionsTableModel, self).__init__(propList=props) self.labels = labels @property def labels(self): + """Gets or sets current labels dataset.""" return self._labels @labels.setter @@ -452,53 +537,41 @@ def labels(self, val): self.beginResetModel() self._labels = val - self._suggestions_list = self.labels.get_suggestions() + self._data = [] + for video, frame_idx in self.labels.get_suggestions(): + item = dict() + + item[ + "video" + ] = f"{self.labels.videos.index(video)}: {os.path.basename(video.filename)}" + item["frame"] = int(frame_idx) + 1 # start at frame 1 rather than 0 + + # show how many labeled instances are in this frame + val = self._labels.instance_count(video, frame_idx) + val = str(val) if val > 0 else "" + item["labeled"] = val + + # calculate score for frame + scores = [ + inst.score + for lf in self.labels.find(video, frame_idx) + for inst in lf + if hasattr(inst, "score") + ] + val = sum(scores) / len(scores) if scores else None + item["mean score"] = val + + self._data.append(item) self.endResetModel() - def data(self, index: QtCore.QModelIndex, role=Qt.DisplayRole): - if role == Qt.DisplayRole and index.isValid(): - idx = index.row() - prop = self._props[index.column()] - - if idx < self.rowCount(): - video = self._suggestions_list[idx][0] - frame_idx = self._suggestions_list[idx][1] - - if prop == "video": - return os.path.basename(video.filename) # just show the name, not full path - elif prop == "frame": - return int(frame_idx) + 1 # start at frame 1 rather than 0 - elif prop == "labeled": - # show how many labeled instances are in this frame - val = self._labels.instance_count(video, frame_idx) - val = str(val) if val > 0 else "" - return val - - return None - - def rowCount(self, *args): - return len(self._suggestions_list) - - def columnCount(self, *args): - return len(self._props) - - def headerData(self, section, orientation: QtCore.Qt.Orientation, role=Qt.DisplayRole): - if role == Qt.DisplayRole: - if orientation == QtCore.Qt.Horizontal: - return self._props[section] - elif orientation == QtCore.Qt.Vertical: - return section - - return None - - def flags(self, index: QtCore.QModelIndex): - return Qt.ItemIsEnabled | Qt.ItemIsSelectable - if __name__ == "__main__": + from PySide2.QtWidgets import QApplication - labels = Labels.load_json("tests/data/json_format_v2/centered_pair_predictions.json") + labels = Labels.load_json( + "tests/data/json_format_v2/centered_pair_predictions.json" + ) skeleton = labels.labels[0].instances[0].skeleton Labels.save_json(labels, "test.json") @@ -512,4 +585,4 @@ def flags(self, index: QtCore.QModelIndex): table = LabeledFrameTable(labels.labels[0], labels) table.show() - app.exec_() \ No newline at end of file + app.exec_() diff --git a/sleap/gui/formbuilder.py b/sleap/gui/formbuilder.py index 0ae8bcd46..1029a2e33 100644 --- a/sleap/gui/formbuilder.py +++ b/sleap/gui/formbuilder.py @@ -1,6 +1,8 @@ -"""Module for creating a form from a yaml file. +""" +Module for creating a form from a yaml file. Example: + >>> widget = YamlFormWidget(yaml_file="example.yaml") >>> widget.mainAction.connect(my_function) @@ -11,8 +13,11 @@ import yaml +from typing import Any, Dict, List, Optional + from PySide2 import QtWidgets, QtCore + class YamlFormWidget(QtWidgets.QGroupBox): """ Custom QWidget which creates form based on yaml file. @@ -25,10 +30,10 @@ class YamlFormWidget(QtWidgets.QGroupBox): mainAction = QtCore.Signal(dict) valueChanged = QtCore.Signal() - def __init__(self, yaml_file, which_form: str="main", *args, **kwargs): + def __init__(self, yaml_file, which_form: str = "main", *args, **kwargs): super(YamlFormWidget, self).__init__(*args, **kwargs) - with open(yaml_file, 'r') as form_yaml: + with open(yaml_file, "r") as form_yaml: items_to_create = yaml.load(form_yaml, Loader=yaml.SafeLoader) self.which_form = which_form @@ -71,12 +76,13 @@ def trigger_main_action(self): """Emit mainAction signal with form data.""" self.mainAction.emit(self.get_form_data()) + class FormBuilderLayout(QtWidgets.QFormLayout): """ Custom QFormLayout which populates itself from list of form fields. Args: - items_to_create: list which gets passed to get_form_data() + items_to_create: list which gets passed to :meth:`get_form_data` (see there for details about format) """ @@ -96,10 +102,12 @@ def get_form_data(self) -> dict: Dict with key:value for each user-editable widget in layout """ widgets = self.fields.values() - data = {w.objectName(): self.get_widget_value(w) - for w in widgets - if len(w.objectName()) - and type(w) not in (QtWidgets.QLabel, QtWidgets.QPushButton)} + data = { + w.objectName(): self.get_widget_value(w) + for w in widgets + if len(w.objectName()) + and type(w) not in (QtWidgets.QLabel, QtWidgets.QPushButton) + } stacked_data = [w.get_data() for w in widgets if type(w) == StackBuilderWidget] for stack in stacked_data: data.update(stack) @@ -109,7 +117,7 @@ def set_form_data(self, data: dict): """Set specified user-editable data. Args: - data (dict): key should match field name + data: dictionary of datay, key should match field name """ widgets = self.fields for name, val in data.items(): @@ -118,10 +126,11 @@ def set_form_data(self, data: dict): self.set_widget_value(widgets[name], val) else: pass -# print(f"no {name} widget found") + + # print(f"no {name} widget found") @staticmethod - def set_widget_value(widget, val): + def set_widget_value(widget: QtWidgets.QWidget, val): """Set value for specific widget.""" # if widget.property("field_data_type") == "sci": # val = str(val) @@ -140,11 +149,13 @@ def set_widget_value(widget, val): widget.repaint() @staticmethod - def get_widget_value(widget): - """Get value of form field (using whichever method appropriate for widget). + def get_widget_value(widget: QtWidgets.QWidget) -> Any: + """Returns value of form field. + + This determines the method appropriate for the type of widget. Args: - widget: subclass of QtWidget + widget: The widget for which to return value. Returns: value (can be bool, numeric, string, or None) """ @@ -169,26 +180,26 @@ def get_widget_value(widget): val = None if val == "None" else val return val - def build_form(self, items_to_create): - """Add widgets to form layout for each item in items_to_create. + def build_form(self, items_to_create: List[Dict[str, Any]]): + """Adds widgets to form layout for each item in items_to_create. Args: - items_to_create: list of dicts with fields + items_to_create: list of dictionaries with keys + * name: used as key when we return form data as dict * label: string to show in form * type: supports double, int, bool, list, button, stack * default: default value for form field - * [options]: comma separated list of options, used for list or stack + * [options]: comma separated list of options, + used for list or stack field-types * for stack, array of dicts w/ form data for each stack page - Note: a "stack" has a dropdown menu that determines which stack page to show + A "stack" has a dropdown menu that determines which stack page to show. Returns: None. """ for item in items_to_create: - field = None - # double: show spinbox (number w/ up/down controls) if item["type"] == "double": field = QtWidgets.QDoubleSpinBox() @@ -259,22 +270,31 @@ def build_form(self, items_to_create): if item["type"].split("_")[0] == "file": self.addRow("", self._make_file_button(item, field)) - def _make_file_button(self, item, field): - file_button = QtWidgets.QPushButton("Select "+item["label"]) + def _make_file_button( + self, item: Dict, field: QtWidgets.QWidget + ) -> QtWidgets.QPushButton: + """Creates the button for a file_* field-type.""" + file_button = QtWidgets.QPushButton("Select " + item["label"]) if item["type"].split("_")[-1] == "open": # Define function for button to trigger def select_file(*args, x=field): filter = item.get("filter", "Any File (*.*)") - filename, _ = QtWidgets.QFileDialog.getOpenFileName(None, directory=None, caption="Open File", filter=filter) - if len(filename): x.setText(filename) + filename, _ = QtWidgets.QFileDialog.getOpenFileName( + None, directory=None, caption="Open File", filter=filter + ) + if len(filename): + x.setText(filename) self.valueChanged.emit() elif item["type"].split("_")[-1] == "dir": # Define function for button to trigger def select_file(*args, x=field): - filename = QtWidgets.QFileDialog.getExistingDirectory(None, directory=None, caption="Open File") - if len(filename): x.setText(filename) + filename = QtWidgets.QFileDialog.getExistingDirectory( + None, directory=None, caption="Open File" + ) + if len(filename): + x.setText(filename) self.valueChanged.emit() else: @@ -283,7 +303,19 @@ def select_file(*args, x=field): file_button.clicked.connect(select_file) return file_button + class StackBuilderWidget(QtWidgets.QWidget): + """ + A custom widget that shows different subforms depending on menu selection. + + Args: + stack_data: Dictionary for field from `items_to_create`. + The "options" key will give the list of options to show in + menu. Each of the "options" will also be the key of a dictionary + within stack_data that has the same structure as the dictionary + passed to :meth:`FormBuilderLayout.build_form()`. + """ + def __init__(self, stack_data, *args, **kwargs): super(StackBuilderWidget, self).__init__(*args, **kwargs) @@ -291,7 +323,9 @@ def __init__(self, stack_data, *args, **kwargs): self.combo_box = QtWidgets.QComboBox() self.stacked_widget = ResizingStackedWidget() - self.combo_box.activated.connect(lambda x: self.stacked_widget.setCurrentIndex(x)) + self.combo_box.activated.connect( + lambda x: self.stacked_widget.setCurrentIndex(x) + ) self.page_layouts = dict() @@ -323,19 +357,35 @@ def __init__(self, stack_data, *args, **kwargs): self.setLayout(multi_layout) def value(self): + """Returns value of menu.""" return self.combo_box.currentText() def get_data(self): + """Returns value from currently shown subform.""" return self.page_layouts[self.value()].get_form_data() class FieldComboWidget(QtWidgets.QComboBox): + """ + A custom ComboBox widget with method to easily add set of options. + """ + def __init__(self, *args, **kwargs): super(FieldComboWidget, self).__init__(*args, **kwargs) self.setSizeAdjustPolicy(QtWidgets.QComboBox.AdjustToContents) self.setMinimumContentsLength(3) - def set_options(self, options_list, select_item=None): + def set_options(self, options_list: List[str], select_item: Optional[str] = None): + """ + Sets list of menu options. + + Args: + options_list: List of items (strings) to show in menu. + select_item: Item to select initially. + + Returns: + None. + """ self.clear() for item in options_list: if item == "---": @@ -348,11 +398,17 @@ def set_options(self, options_list, select_item=None): class ResizingStackedWidget(QtWidgets.QStackedWidget): + """ + QStackedWidget that updates its sizeHint and minimumSizeHint as needed. + """ + def __init__(self, *args, **kwargs): super(ResizingStackedWidget, self).__init__(*args, **kwargs) def sizeHint(self): + """Qt method.""" return self.currentWidget().sizeHint() def minimumSizeHint(self): + """Qt method.""" return self.currentWidget().minimumSizeHint() diff --git a/sleap/gui/importvideos.py b/sleap/gui/importvideos.py index 3ba51fa99..84d478e83 100644 --- a/sleap/gui/importvideos.py +++ b/sleap/gui/importvideos.py @@ -17,25 +17,33 @@ method while passing the user-selected params as the named parameters: >>> vid = item["video_class"](**item["params"]) + """ from PySide2.QtCore import Qt, QRectF, Signal from PySide2.QtWidgets import QApplication, QLayout, QVBoxLayout, QHBoxLayout, QFrame from PySide2.QtWidgets import QFileDialog, QDialog, QWidget, QLabel, QScrollArea -from PySide2.QtWidgets import QPushButton, QButtonGroup, QRadioButton, QCheckBox, QComboBox, QStackedWidget +from PySide2.QtWidgets import ( + QPushButton, + QButtonGroup, + QRadioButton, + QCheckBox, + QComboBox, +) from sleap.gui.video import GraphicsView -from sleap.io.video import Video, HDF5Video, MediaVideo +from sleap.io.video import Video import h5py import qimage2ndarray + class ImportVideos: """Class to handle video importing UI.""" - + def __init__(self): self.result = [] - + def go(self): """Runs the import UI. @@ -48,20 +56,21 @@ def go(self): List with dict of the parameters for each file to import. """ dialog = QFileDialog() - #dialog.setOption(QFileDialog.Option.DontUseNativeDialogs, True) + # dialog.setOption(QFileDialog.Option.DontUseNativeDialogs, True) file_names, filter = dialog.getOpenFileNames( - None, - "Select videos to import...", # dialogue title - ".", # initial path - "Any Video (*.h5 *.hd5v *.mp4 *.avi *.json);;HDF5 (*.h5 *.hd5v);;ImgStore (*.json);;Media Video (*.mp4 *.avi);;Any File (*.*)", # filters - #options=QFileDialog.DontUseNativeDialog - ) + None, + "Select videos to import...", # dialogue title + ".", # initial path + "Any Video (*.h5 *.hd5v *.mp4 *.avi *.json);;HDF5 (*.h5 *.hd5v);;ImgStore (*.json);;Media Video (*.mp4 *.avi);;Any File (*.*)", # filters + # options=QFileDialog.DontUseNativeDialog + ) if len(file_names) > 0: importer = ImportParamDialog(file_names) - importer.accepted.connect(lambda:importer.get_data(self.result)) + importer.accepted.connect(lambda: importer.get_data(self.result)) importer.exec_() return self.result + class ImportParamDialog(QDialog): """Dialog for selecting parameters with preview when importing video. @@ -69,13 +78,13 @@ class ImportParamDialog(QDialog): file_names (list): List of files we want to import. """ - def __init__(self, file_names:list, *args, **kwargs): + def __init__(self, file_names: list, *args, **kwargs): super(ImportParamDialog, self).__init__(*args, **kwargs) - + self.import_widgets = [] - + self.setWindowTitle("Video Import Options") - + self.import_types = [ { "video_type": "hdf5", @@ -85,55 +94,53 @@ def __init__(self, file_names:list, *args, **kwargs): { "name": "dataset", "type": "function_menu", - "options": "_get_h5_dataset_options" + "options": "_get_h5_dataset_options", + "required": True, }, { "name": "input_format", "type": "radio", - "options": "channels_first,channels_last" - } - ] + "options": "channels_first,channels_last", + }, + ], }, { "video_type": "mp4", "match": "mp4,avi", "video_class": Video.from_media, - "params": [ - { - "name": "grayscale", - "type": "check" - } - ] + "params": [{"name": "grayscale", "type": "check"}], }, { "video_type": "numpy", "match": "npy", "video_class": Video.from_numpy, - "params": [] + "params": [], }, { "video_type": "imgstore", "match": "json", "video_class": Video.from_filename, - "params": [] - } + "params": [], + }, ] - + outer_layout = QVBoxLayout() - + scroll_widget = QScrollArea() - #scroll_widget.setWidgetResizable(False) + # scroll_widget.setWidgetResizable(False) scroll_widget.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn) scroll_widget.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) - + scroll_items_widget = QWidget() scroll_layout = QVBoxLayout() for file_name in file_names: if file_name: this_type = None for import_type in self.import_types: - if import_type.get("match",None) is not None: - if file_name.lower().endswith(tuple(import_type["match"].split(","))): + if import_type.get("match", None) is not None: + if file_name.lower().endswith( + tuple(import_type["match"].split(",")) + ): this_type = import_type break if this_type is not None: @@ -145,7 +152,7 @@ def __init__(self, file_names:list, *args, **kwargs): scroll_items_widget.setLayout(scroll_layout) scroll_widget.setWidget(scroll_items_widget) outer_layout.addWidget(scroll_widget) - + button_layout = QHBoxLayout() cancel_button = QPushButton("Cancel") import_button = QPushButton("Import") @@ -153,15 +160,15 @@ def __init__(self, file_names:list, *args, **kwargs): button_layout.addStretch() button_layout.addWidget(cancel_button) button_layout.addWidget(import_button) - + outer_layout.addLayout(button_layout) - + self.setLayout(outer_layout) - + import_button.clicked.connect(self.accept) cancel_button.clicked.connect(self.reject) - def get_data(self, import_result = None): + def get_data(self, import_result=None): """Method to get results from import. Args: @@ -186,6 +193,7 @@ def paint(self, painter, option, widget=None): """Method required by Qt.""" pass + class ImportItemWidget(QFrame): """Widget for selecting parameters with preview when importing video. @@ -193,39 +201,41 @@ class ImportItemWidget(QFrame): file_path (str): Full path to selected video file. import_type (dict): Data about user-selectable import parameters. """ - + def __init__(self, file_path: str, import_type: dict, *args, **kwargs): super(ImportItemWidget, self).__init__(*args, **kwargs) - + self.file_path = file_path self.import_type = import_type self.video = None - + import_item_layout = QVBoxLayout() - + self.enabled_checkbox_widget = QCheckBox(self.file_path) self.enabled_checkbox_widget.setChecked(True) import_item_layout.addWidget(self.enabled_checkbox_widget) - - #import_item_layout.addWidget(QLabel(self.file_path)) + + # import_item_layout.addWidget(QLabel(self.file_path)) inner_layout = QHBoxLayout() - self.options_widget = ImportParamWidget(parent=self, file_path = self.file_path, import_type = self.import_type) + self.options_widget = ImportParamWidget( + parent=self, file_path=self.file_path, import_type=self.import_type + ) self.preview_widget = VideoPreviewWidget(parent=self) self.preview_widget.setFixedSize(200, 200) - + self.enabled_checkbox_widget.stateChanged.connect( - lambda state:self.options_widget.setEnabled(state == Qt.Checked) + lambda state: self.options_widget.setEnabled(state == Qt.Checked) ) - + inner_layout.addWidget(self.options_widget) inner_layout.addWidget(self.preview_widget) import_item_layout.addLayout(inner_layout) self.setLayout(import_item_layout) - + self.setFrameStyle(QFrame.Panel) - + self.options_widget.changed.connect(self.update_video) - self.update_video() + self.update_video(initial=True) def is_enabled(self): """Am I enabled? @@ -244,35 +254,42 @@ def get_data(self) -> dict: Returns: Dict with data for this video. """ - + video_data = { - "params": self.options_widget.get_values(), - "video_type": self.import_type["video_type"], - "video_class": self.import_type["video_class"], - } + "params": self.options_widget.get_values(), + "video_type": self.import_type["video_type"], + "video_class": self.import_type["video_class"], + } return video_data - def update_video(self): + def update_video(self, initial=False): """Update preview video using current param values. - + + Args: + initial: if True, then get video settings that are used by + the `Video` object when they aren't specified as params Returns: None. """ - - video_params = self.options_widget.get_values() + + video_params = self.options_widget.get_values(only_required=initial) + try: if self.import_type["video_class"] is not None: self.video = self.import_type["video_class"](**video_params) else: self.video = None - + self.preview_widget.load_video(self.video) except Exception as e: - print(e) + print(f"Unable to load video with these parameters. Error: {e}") # if we got an error showing video with those settings, clear the video preview self.video = None self.preview_widget.clear_video() + if initial and self.video is not None: + self.options_widget.set_values_from_video(self.video) + def boundingRect(self) -> QRectF: """Method required by Qt.""" return QRectF() @@ -281,6 +298,7 @@ def paint(self, painter, option, widget=None): """Method required by Qt.""" pass + class ImportParamWidget(QWidget): """Widget for allowing user to select video parameters. @@ -294,29 +312,29 @@ class ImportParamWidget(QWidget): changed = Signal() - def __init__(self, file_path:str, import_type:dict, *args, **kwargs): + def __init__(self, file_path: str, import_type: dict, *args, **kwargs): super(ImportParamWidget, self).__init__(*args, **kwargs) - + self.file_path = file_path self.import_type = import_type self.widget_elements = {} self.video_params = {} - + option_layout = self.make_layout() - #self.changed.connect( lambda: print(self.get_values()) ) - + # self.changed.connect( lambda: print(self.get_values()) ) + self.setLayout(option_layout) - + def make_layout(self) -> QLayout: """Builds the layout of widgets for user-selected import parameters.""" - + param_list = self.import_type["params"] widget_layout = QVBoxLayout() widget_elements = dict() for param_item in param_list: name = param_item["name"] type = param_item["type"] - options = param_item.get("options",None) + options = param_item.get("options", None) if type == "radio": radio_group = QButtonGroup(parent=self) option_list = options.split(",") @@ -327,11 +345,11 @@ def make_layout(self) -> QLayout: btn_widget.setChecked(True) widget_layout.addWidget(btn_widget) radio_group.addButton(btn_widget) - radio_group.buttonToggled.connect(lambda:self.changed.emit()) + radio_group.buttonToggled.connect(lambda: self.changed.emit()) widget_elements[name] = radio_group elif type == "check": check_widget = QCheckBox(name) - check_widget.stateChanged.connect(lambda:self.changed.emit()) + check_widget.stateChanged.connect(lambda: self.changed.emit()) widget_layout.addWidget(check_widget) widget_elements[name] = check_widget elif type == "function_menu": @@ -340,17 +358,18 @@ def make_layout(self) -> QLayout: option_list = getattr(self, options)() for option in option_list: list_widget.addItem(option) - list_widget.currentIndexChanged.connect(lambda:self.changed.emit()) + list_widget.currentIndexChanged.connect(lambda: self.changed.emit()) widget_layout.addWidget(list_widget) widget_elements[name] = list_widget self.widget_elements = widget_elements return widget_layout - - def get_values(self): + + def get_values(self, only_required=False): """Method to get current user-selected values for import parameters. Args: - None. + only_required: Only return the parameters that are required + for instantiating `Video` object Returns: Dict of param keys/values. @@ -366,16 +385,39 @@ def get_values(self): for param_item in param_list: name = param_item["name"] type = param_item["type"] - value = None - if type == "radio": - value = self.widget_elements[name].checkedButton().text() - elif type == "check": - value = self.widget_elements[name].isChecked() - elif type == "function_menu": - value = self.widget_elements[name].currentText() - param_values[name] = value + is_required = param_item.get("required", False) + + if not only_required or is_required: + value = None + if type == "radio": + value = self.widget_elements[name].checkedButton().text() + elif type == "check": + value = self.widget_elements[name].isChecked() + elif type == "function_menu": + value = self.widget_elements[name].currentText() + param_values[name] = value return param_values - + + def set_values_from_video(self, video): + """Set the form fields using attributes on video.""" + param_list = self.import_type["params"] + for param in param_list: + name = param["name"] + type = param["type"] + print(name, type) + if hasattr(video, name): + val = getattr(video, name) + print(name, val) + widget = self.widget_elements[name] + if hasattr(widget, "isChecked"): + widget.setChecked(val) + elif hasattr(widget, "value"): + widget.setValue(val) + elif hasattr(widget, "currentText"): + widget.setCurrentText(str(val)) + elif hasattr(widget, "text"): + widget.setText(str(val)) + def _get_h5_dataset_options(self) -> list: """Method to get a list of all datasets in hdf5 file. @@ -389,12 +431,12 @@ def _get_h5_dataset_options(self) -> list: This is used to populate the "function_menu"-type param. """ try: - with h5py.File(self.file_path,"r") as f: - options = self._find_h5_datasets("",f) + with h5py.File(self.file_path, "r") as f: + options = self._find_h5_datasets("", f) except Exception as e: options = [] return options - + def _find_h5_datasets(self, data_path, data_object) -> list: """Recursively find datasets in hdf5 file.""" options = [] @@ -403,7 +445,9 @@ def _find_h5_datasets(self, data_path, data_object) -> list: if len(data_object[key].shape) == 4: options.append(data_path + "/" + key) elif isinstance(data_object[key], h5py._hl.group.Group): - options.extend(self._find_h5_datasets(data_path + "/" + key, data_object[key])) + options.extend( + self._find_h5_datasets(data_path + "/" + key, data_object[key]) + ) return options def boundingRect(self) -> QRectF: @@ -439,28 +483,33 @@ def __init__(self, video: Video = None, *args, **kwargs): self.layout.addWidget(self.video_label) self.setLayout(self.layout) self.view.show() - + if video is not None: self.load_video(video) - + def clear_video(self): """Clear the video preview.""" self.view.clear() - + def load_video(self, video: Video, initial_frame=0, plot=True): """Load the video preview and display label text.""" self.video = video self.frame_idx = initial_frame - label = "(%d, %d), %d f, %d c" % (self.video.width, self.video.height, self.video.frames, self.video.channels) + label = "(%d, %d), %d f, %d c" % ( + self.video.width, + self.video.height, + self.video.frames, + self.video.channels, + ) self.video_label.setText(label) if plot: self.plot(initial_frame) - + def plot(self, idx=0): """Show the video preview.""" if self.video is None: return - + # Get image data frame = self.video.get_frame(idx) # Clear existing objects @@ -482,9 +531,12 @@ def paint(self, painter, option, widget=None): if __name__ == "__main__": app = QApplication([]) - + import_list = ImportVideos().go() - + for import_item in import_list: vid = import_item["video_class"](**import_item["params"]) - print("Imported video data: (%d, %d), %d f, %d c" % (vid.width, vid.height, vid.frames, vid.channels)) + print( + "Imported video data: (%d, %d), %d f, %d c" + % (vid.width, vid.height, vid.frames, vid.channels) + ) diff --git a/sleap/gui/merge.py b/sleap/gui/merge.py new file mode 100644 index 000000000..51ba7a82a --- /dev/null +++ b/sleap/gui/merge.py @@ -0,0 +1,317 @@ +""" +Gui for merging two labels files with options to resolve conflicts. +""" + +import attr + +from typing import Dict, List + +from sleap.instance import LabeledFrame +from sleap.io.dataset import Labels + +from PySide2 import QtWidgets, QtCore + +USE_BASE_STRING = "Use base, discard conflicting new instances" +USE_NEW_STRING = "Use new, discard conflicting base instances" +USE_NEITHER_STRING = "Discard all conflicting instances" +CLEAN_STRING = "Accept clean merge" + + +class MergeDialog(QtWidgets.QDialog): + """ + Dialog window for complex merging of two SLEAP datasets. + + This will immediately merge any labeled frames that can be cleanly merged, + show summary of merge and prompt user about how to handle merge conflict, + and then finish merge (resolving conflicts as the user requested). + """ + + def __init__(self, base_labels: Labels, new_labels: Labels, *args, **kwargs): + """ + Creates merge dialog and begins merging. + + Args: + base_labels: The base dataset into which we're inserting data. + new_labels: New dataset from which we're getting data to insert. + + Returns: + None. + """ + + super(MergeDialog, self).__init__(*args, **kwargs) + + self.base_labels = base_labels + self.new_labels = new_labels + + merged, self.extra_base, self.extra_new = Labels.complex_merge_between( + self.base_labels, self.new_labels + ) + + merge_total = 0 + merge_frames = 0 + for vid_frame_list in merged.values(): + # number of frames for this video + merge_frames += len(vid_frame_list.keys()) + # number of instances across frames for this video + merge_total += sum((map(len, vid_frame_list.values()))) + + layout = QtWidgets.QVBoxLayout() + + merged_text = f"Cleanly merged {merge_total} instances" + if merge_total: + merged_text += f" across {merge_frames} frames" + merged_text += "." + merged_label = QtWidgets.QLabel(merged_text) + layout.addWidget(merged_label) + + if merge_total: + merge_table = MergeTable(merged) + layout.addWidget(merge_table) + + if not self.extra_base: + conflict_text = "There are no conflicts." + else: + conflict_text = "Merge conflicts:" + + conflict_label = QtWidgets.QLabel(conflict_text) + layout.addWidget(conflict_label) + + if self.extra_base: + conflict_table = ConflictTable( + self.base_labels, self.extra_base, self.extra_new + ) + layout.addWidget(conflict_table) + + self.merge_method = QtWidgets.QComboBox() + if self.extra_base: + self.merge_method.addItem(USE_NEW_STRING) + self.merge_method.addItem(USE_BASE_STRING) + self.merge_method.addItem(USE_NEITHER_STRING) + else: + self.merge_method.addItem(CLEAN_STRING) + layout.addWidget(self.merge_method) + + buttons = QtWidgets.QDialogButtonBox() + buttons.addButton("Finish Merge", QtWidgets.QDialogButtonBox.AcceptRole) + buttons.accepted.connect(self.finishMerge) + + layout.addWidget(buttons) + + self.setLayout(layout) + + def finishMerge(self): + """ + Finishes merge process, possibly resolving conflicts. + + This is connected to `accepted` signal. + + Args: + None. + + Raises: + ValueError: If no valid merge method was selected in dialog. + + Returns: + None. + """ + merge_method = self.merge_method.currentText() + if merge_method == USE_BASE_STRING: + Labels.finish_complex_merge(self.base_labels, self.extra_base) + elif merge_method == USE_NEW_STRING: + Labels.finish_complex_merge(self.base_labels, self.extra_new) + elif merge_method in (USE_NEITHER_STRING, CLEAN_STRING): + Labels.finish_complex_merge(self.base_labels, []) + else: + raise ValueError("No valid merge method selected.") + + self.accept() + + +class ConflictTable(QtWidgets.QTableView): + """ + Qt table view for summarizing merge conflicts. + + Arguments are passed through to the table view object. + + The two lists of `LabeledFrame` objects should be correlated (idx in one will + match idx of the conflicting frame in other). + + Args: + base_labels: The base dataset. + extra_base: `LabeledFrame` objects from base that conflicted. + extra_new: `LabeledFrame` objects from new dataset that conflicts. + """ + + def __init__( + self, + base_labels: Labels, + extra_base: List[LabeledFrame], + extra_new: List[LabeledFrame], + ): + super(ConflictTable, self).__init__() + self.setModel(ConflictTableModel(base_labels, extra_base, extra_new)) + + +class ConflictTableModel(QtCore.QAbstractTableModel): + """Qt table model for summarizing merge conflicts. + + See :class:`ConflictTable`. + """ + + _props = ["video", "frame", "base", "new"] + + def __init__( + self, + base_labels: Labels, + extra_base: List[LabeledFrame], + extra_new: List[LabeledFrame], + ): + super(ConflictTableModel, self).__init__() + self.base_labels = base_labels + self.extra_base = extra_base + self.extra_new = extra_new + + def data(self, index: QtCore.QModelIndex, role=QtCore.Qt.DisplayRole): + """Required by Qt.""" + if role == QtCore.Qt.DisplayRole and index.isValid(): + idx = index.row() + prop = self._props[index.column()] + + if idx < self.rowCount(): + if prop == "video": + return self.extra_base[idx].video.filename + if prop == "frame": + return self.extra_base[idx].frame_idx + if prop == "base": + return show_instance_type_counts(self.extra_base[idx]) + if prop == "new": + return show_instance_type_counts(self.extra_new[idx]) + + return None + + def rowCount(self, *args): + """Required by Qt.""" + return len(self.extra_base) + + def columnCount(self, *args): + """Required by Qt.""" + return len(self._props) + + def headerData( + self, section, orientation: QtCore.Qt.Orientation, role=QtCore.Qt.DisplayRole + ): + """Required by Qt.""" + if role == QtCore.Qt.DisplayRole: + if orientation == QtCore.Qt.Horizontal: + return self._props[section] + elif orientation == QtCore.Qt.Vertical: + return section + return None + + +class MergeTable(QtWidgets.QTableView): + """ + Qt table view for summarizing cleanly merged frames. + + Arguments are passed through to the table view object. + + Args: + merged: The frames that were cleanly merged. + See :meth:`Labels.complex_merge_between` for details. + """ + + def __init__(self, merged, *args, **kwargs): + super(MergeTable, self).__init__() + self.setModel(MergeTableModel(merged)) + + +class MergeTableModel(QtCore.QAbstractTableModel): + """Qt table model for summarizing merge conflicts. + + See :class:`MergeTable`. + """ + + _props = ["video", "frame", "merged instances"] + + def __init__(self, merged: Dict["Video", Dict[int, List["Instance"]]]): + super(MergeTableModel, self).__init__() + self.merged = merged + + self.data_table = [] + for video in self.merged.keys(): + for frame_idx, frame_instance_list in self.merged[video].items(): + self.data_table.append( + dict( + filename=video.filename, + frame_idx=frame_idx, + instances=frame_instance_list, + ) + ) + + def data(self, index: QtCore.QModelIndex, role=QtCore.Qt.DisplayRole): + """Required by Qt.""" + if role == QtCore.Qt.DisplayRole and index.isValid(): + idx = index.row() + prop = self._props[index.column()] + + if idx < self.rowCount(): + if prop == "video": + return self.data_table[idx]["filename"] + if prop == "frame": + return self.data_table[idx]["frame_idx"] + if prop == "merged instances": + return show_instance_type_counts(self.data_table[idx]["instances"]) + + return None + + def rowCount(self, *args): + """Required by Qt.""" + return len(self.data_table) + + def columnCount(self, *args): + """Required by Qt.""" + return len(self._props) + + def headerData( + self, section, orientation: QtCore.Qt.Orientation, role=QtCore.Qt.DisplayRole + ): + """Required by Qt.""" + if role == QtCore.Qt.DisplayRole: + if orientation == QtCore.Qt.Horizontal: + return self._props[section] + elif orientation == QtCore.Qt.Vertical: + return section + return None + + +def show_instance_type_counts(instance_list: List["Instance"]) -> str: + """ + Returns string of instance counts to show in table. + + Args: + instance_list: The list of instances to count. + + Returns: + String with numbers of user/predicted instances. + """ + prediction_count = len( + list(filter(lambda inst: hasattr(inst, "score"), instance_list)) + ) + user_count = len(instance_list) - prediction_count + return f"{user_count}/{prediction_count}" + + +if __name__ == "__main__": + + # file_a = "tests/data/json_format_v1/centered_pair.json" + # file_b = "tests/data/json_format_v2/centered_pair_predictions.json" + file_a = "files/merge/a.h5" + file_b = "files/merge/b.h5" + + base_labels = Labels.load_file(file_a) + new_labels = Labels.load_file(file_b) + + app = QtWidgets.QApplication() + win = MergeDialog(base_labels, new_labels) + win.show() + app.exec_() diff --git a/sleap/gui/multicheck.py b/sleap/gui/multicheck.py index 082df6911..5af7137e5 100644 --- a/sleap/gui/multicheck.py +++ b/sleap/gui/multicheck.py @@ -1,31 +1,43 @@ """ -Module for Qt Widget to show multiple checkboxes for selecting from a sequence of numbers. +Module for Qt Widget to show multiple checkboxes for selecting. Example: >>> mc = MultiCheckWidget(count=5, selected=[0,1],title="My Items") - >>> me.selectionChanged.connect(window.plot) + >>> mc.selectionChanged.connect(window.plot) >>> window.layout.addWidget(mc) """ + +from typing import List, Optional + from PySide2.QtCore import QRectF, Signal from PySide2.QtWidgets import QGridLayout, QGroupBox, QButtonGroup, QCheckBox + class MultiCheckWidget(QGroupBox): - """Qt Widget to show multiple checkboxes for selecting from a sequence of numbers. + """Qt Widget to show multiple checkboxes for a sequence of numbers. Args: - count (int): The number of checkboxes to show. - title (str, optional): Display title for group of checkboxes. - selected (list, optional): List of checkbox numbers to initially have checked. - default (bool, optional): Default to checked/unchecked (ignored if selected arg given). + count: The number of checkboxes to show. + title: Display title for group of checkboxes. + selected: List of checkbox numbers to initially check. + default: Whether to default boxes as checked. """ - def __init__(self, *args, count, title="", selected=None, default=False, **kwargs): + def __init__( + self, + *args, + count: int, + title: Optional[str] = "", + selected: Optional[List] = None, + default: Optional[bool] = False, + **kwargs + ): super(MultiCheckWidget, self).__init__(*args, **kwargs) # QButtonGroup is the logical container # it allows us to get list of checked boxes more easily self.check_group = QButtonGroup() - self.check_group.setExclusive(False) # more than one can be checked + self.check_group.setExclusive(False) # more than one can be checked if title != "": self.setTitle(title) @@ -39,15 +51,15 @@ def __init__(self, *args, count, title="", selected=None, default=False, **kwarg check_layout = QGridLayout() self.setLayout(check_layout) for i in range(count): - check = QCheckBox("%d"%(i)) + check = QCheckBox("%d" % (i)) # call signal/slot on self when one of the checkboxes is changed check.stateChanged.connect(lambda e: self.selectionChanged.emit()) self.check_group.addButton(check, i) - check_layout.addWidget(check, i//8, i%8) + check_layout.addWidget(check, i // 8, i % 8) self.setSelected(selected) """ - selectionChanged signal is sent whenever one of the checkboxes gets a stateChanged signal. + selectionChanged signal sent when a checkbox gets a stateChanged signal """ selectionChanged = Signal() @@ -67,7 +79,7 @@ def setSelected(self, selected: list): """Method to set some checkboxes as checked. Args: - selected (list): List of checkboxes to check. + selected: List of checkboxes to check. Returns: None diff --git a/sleap/gui/overlays/__init__.py b/sleap/gui/overlays/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/sleap/gui/overlays/anchors.py b/sleap/gui/overlays/anchors.py new file mode 100644 index 000000000..1914545f5 --- /dev/null +++ b/sleap/gui/overlays/anchors.py @@ -0,0 +1,52 @@ +""" +Module with overlay for showing negative training sample anchors. +""" +import attr + +from PySide2 import QtGui + +from sleap.gui.video import QtVideoPlayer +from sleap.io.dataset import Labels + + +@attr.s(auto_attribs=True) +class NegativeAnchorOverlay: + """Class to overlay of negative training sample anchors to video frame. + + Attributes: + labels: The :class:`Labels` dataset from which to get overlay data. + player: The video player in which to show overlay. + """ + + labels: Labels = None + player: QtVideoPlayer = None + _pen = QtGui.QPen(QtGui.QColor("red")) + _line_len: int = 3 + + def add_to_scene(self, video, frame_idx): + """Adds anchor markers as overlay on frame image.""" + if self.labels is None: + return + if video not in self.labels.negative_anchors: + return + + anchors = self.labels.negative_anchors[video] + for idx, x, y in anchors: + if frame_idx == idx: + self._add(x, y) + + def _add(self, x, y): + self.player.scene.addLine( + x - self._line_len, + y - self._line_len, + x + self._line_len, + y + self._line_len, + self._pen, + ) + self.player.scene.addLine( + x + self._line_len, + y - self._line_len, + x - self._line_len, + y + self._line_len, + self._pen, + ) diff --git a/sleap/gui/overlays/base.py b/sleap/gui/overlays/base.py index 0aac7faf8..5d4d92549 100644 --- a/sleap/gui/overlays/base.py +++ b/sleap/gui/overlays/base.py @@ -6,56 +6,71 @@ import numpy as np from typing import Sequence +import sleap from sleap.io.video import Video, HDF5Video from sleap.gui.video import QtVideoPlayer from sleap.nn.transform import DataTransform + class HDF5Data(HDF5Video): def __getitem__(self, i): """Get data for frame i from `HDF5Video` object.""" x = self.get_frame(i) - return np.clip(x,0,1) + return np.clip(x, 0, 1) + @attr.s(auto_attribs=True) class ModelData: - model: 'keras.Model' + inference_model: "sleap.nn.inference.InferenceModel" video: Video - do_rescale: bool=False + do_rescale: bool = False + output_scale: float = 1.0 + adjust_vals: bool = True def __getitem__(self, i): """Data data for frame i from predictor.""" frame_img = self.video[i] # Trim to size that works for model - frame_img = frame_img[:, :self.video.height//8*8, :self.video.width//8*8, :] - + size_reduction = 2 ** (self.inference_model.down_blocks) + input_size = ( + (self.video.height // size_reduction) * size_reduction, + (self.video.width // size_reduction) * size_reduction, + self.video.channels, + ) + frame_img = frame_img[:, : input_size[0], : input_size[1], :] + + inference_transform = DataTransform() if self.do_rescale: # Scale input image if model trained on scaled images - inference_transform = DataTransform() frame_img = inference_transform.scale_to( - imgs=frame_img, - target_size=self.model.input_shape[1:3]) + imgs=frame_img, target_size=self.inference_model.input_shape[1:3] + ) # Get predictions - frame_result = self.model.predict(frame_img.astype("float32") / 255) - if self.do_rescale: + frame_result = self.inference_model.predict(frame_img) + if self.do_rescale or self.output_scale != 1.0: + inference_transform.scale *= self.output_scale frame_result = inference_transform.invert_scale(frame_result) # We just want the single image results - frame_result = frame_result[0] + if type(i) != slice: + frame_result = frame_result[0] - # If max value is below 1, amplify values so max is 1. - # This allows us to visualize model with small ptp value - # even though this model may not give us adequate predictions. - max_val = np.max(frame_result) - if max_val < 1: - frame_result = frame_result/np.max(frame_result) + if self.adjust_vals: + # If max value is below 1, amplify values so max is 1. + # This allows us to visualize model with small ptp value + # even though this model may not give us adequate predictions. + max_val = np.max(frame_result) + if max_val < 1: + frame_result = frame_result / np.max(frame_result) - # Clip values to ensure that they're within [0, 1] - frame_result = np.clip(frame_result, 0, 1) + # Clip values to ensure that they're within [0, 1] + frame_result = np.clip(frame_result, 0, 1) return frame_result + @attr.s(auto_attribs=True) class DataOverlay: @@ -65,7 +80,24 @@ class DataOverlay: transform: DataTransform = None def add_to_scene(self, video, frame_idx): - if self.data is None: return + if self.data is None: + return + + # Check if video matches video for ModelData object + if hasattr(self.data, "video") and self.data.video != video: + video_shape = (video.height, video.width, video.channels) + prior_shape = ( + self.data.video.height, + self.data.video.width, + self.data.video.channels, + ) + # Check if the videos are both compatible with the loaded model + if video_shape == prior_shape: + # Shapes match so we can apply model to this video + self.data.video = video + else: + # Shapes don't match so don't do anything with this video + return if self.transform is None: self._add(self.player.view.scene, self.overlay_class(self.data[frame_idx])) @@ -73,8 +105,11 @@ def add_to_scene(self, video, frame_idx): else: # If data indices are different than frame indices, use data # index; otherwise just use frame index. - idxs = self.transform.get_data_idxs(frame_idx) \ - if self.transform.frame_idxs else [frame_idx] + idxs = ( + self.transform.get_data_idxs(frame_idx) + if self.transform.frame_idxs + else [frame_idx] + ) # Loop over indices, in case there's more than one for frame for idx in idxs: @@ -84,12 +119,17 @@ def add_to_scene(self, video, frame_idx): x, y = 0, 0 overlay_object = self.overlay_class( - self.data[idx], - scale=self.transform.scale) + self.data[idx], scale=self.transform.scale + ) - self._add(self.player.view.scene, overlay_object, (x,y)) + self._add(self.player.view.scene, overlay_object, (x, y)) - def _add(self, to: QtWidgets.QGraphicsScene, what: QtWidgets.QGraphicsObject, where: tuple=(0,0)): + def _add( + self, + to: QtWidgets.QGraphicsScene, + what: QtWidgets.QGraphicsObject, + where: tuple = (0, 0), + ): to.addItem(what) what.setPos(*where) @@ -97,49 +137,56 @@ def _add(self, to: QtWidgets.QGraphicsScene, what: QtWidgets.QGraphicsObject, wh def from_h5(cls, filename, dataset, input_format="channels_last", **kwargs): import h5py as h5 - with h5.File(filename, 'r') as f: + with h5.File(filename, "r") as f: frame_idxs = np.asarray(f["frame_idxs"], dtype="int") bounding_boxes = np.asarray(f["bounds"]) transform = DataTransform(frame_idxs=frame_idxs, bounding_boxes=bounding_boxes) - data_object = HDF5Data(filename, dataset, input_format=input_format, convert_range=False) + data_object = HDF5Data( + filename, dataset, input_format=input_format, convert_range=False + ) return cls(data=data_object, transform=transform, **kwargs) @classmethod def from_model(cls, filename, video, **kwargs): from sleap.nn.model import ModelOutputType - from sleap.nn.loadmodel import load_model, get_model_data + from sleap.nn.inference import InferenceModel from sleap.nn.training import TrainingJob # Load the trained model + training_job = TrainingJob.load_json(filename) + inference_model = InferenceModel(training_job) - trainingjob = TrainingJob.load_json(filename) - - input_size = (video.height//8*8, video.width//8*8, video.channels) - model_output_type = trainingjob.model.output_type - - model = load_model( - sleap_models={model_output_type:trainingjob}, - input_size=input_size, - output_types=[model_output_type]) - - model_data = get_model_data( - sleap_models={model_output_type:trainingjob}, - output_types=[model_output_type]) + size_reduction = 2 ** (inference_model.down_blocks) + input_size = ( + (video.height // size_reduction) * size_reduction, + (video.width // size_reduction) * size_reduction, + video.channels, + ) + model_output_type = training_job.model.output_type # Here we determine if the input should be scaled. If so, then # the output of the model will also be rescaled accordingly. + do_rescale = inference_model.input_scale != 1.0 - do_rescale = model_data["scale"] < 1 + # Determine how the output from the model should be scaled + img_output_scale = 1.0 # image rescaling + obj_output_scale = 1.0 # scale to pass to overlay object - # Construct the ModelData object that runs inference + if model_output_type == ModelOutputType.PART_AFFINITY_FIELD: + obj_output_scale = inference_model.output_relative_scale - data_object = ModelData(model, video, do_rescale=do_rescale) + else: + img_output_scale = inference_model.output_relative_scale - # Determine whether to use confmap or paf overlay + # Construct the ModelData object that runs inference + data_object = ModelData( + inference_model, video, do_rescale=do_rescale, output_scale=img_output_scale + ) + # Determine whether to use confmap or paf overlay from sleap.gui.overlays.confmaps import ConfMapsPlot from sleap.gui.overlays.pafs import MultiQuiverPlot @@ -153,20 +200,18 @@ def from_model(cls, filename, video, **kwargs): # This doesn't require rescaling the input, and the "scale" # will be passed to the overlay object to do its own upscaling # (at least for pafs). - - transform = DataTransform(scale=model_data["multiscale"]) + transform = DataTransform(scale=obj_output_scale) return cls( - data=data_object, - transform=transform, - overlay_class=overlay_class, - **kwargs) + data=data_object, transform=transform, overlay_class=overlay_class, **kwargs + ) + h5_colors = [ [204, 81, 81], - [127, 51, 51], [81, 204, 204], [51, 127, 127], + [127, 51, 51], [142, 204, 81], [89, 127, 51], [142, 81, 204], @@ -212,5 +257,5 @@ def from_model(cls, filename, video, **kwargs): [81, 204, 181], [51, 127, 113], [81, 181, 204], - [51, 113, 127] -] \ No newline at end of file + [51, 113, 127], +] diff --git a/sleap/gui/overlays/confmaps.py b/sleap/gui/overlays/confmaps.py index 5629cbd40..40d86abc9 100644 --- a/sleap/gui/overlays/confmaps.py +++ b/sleap/gui/overlays/confmaps.py @@ -8,24 +8,29 @@ from PySide2 import QtWidgets, QtCore, QtGui -import attr import numpy as np import qimage2ndarray -from typing import Sequence -from sleap.io.video import Video, HDF5Video -from sleap.gui.video import QtVideoPlayer from sleap.gui.overlays.base import DataOverlay, h5_colors + class ConfmapOverlay(DataOverlay): + """Overlay to show confidence maps.""" @classmethod def from_h5(cls, filename, input_format="channels_last", **kwargs): - return DataOverlay.from_h5(filename, "/confmaps", input_format, overlay_class=ConfMapsPlot, **kwargs) + """Create object with hdf5 as datasource.""" + return DataOverlay.from_h5( + filename, "/confmaps", input_format, overlay_class=ConfMapsPlot, **kwargs + ) @classmethod def from_model(cls, filename, video, **kwargs): - return DataOverlay.from_model(filename, video, overlay_class=ConfMapsPlot, **kwargs) + """Create object with live predictions from model as datasource.""" + return DataOverlay.from_model( + filename, video, overlay_class=ConfMapsPlot, **kwargs + ) + class ConfMapsPlot(QtWidgets.QGraphicsObject): """QGraphicsObject to display multiple confidence maps in a QGraphicsView. @@ -42,7 +47,9 @@ class ConfMapsPlot(QtWidgets.QGraphicsObject): When initialized, creates one child ConfMapPlot item for each channel. """ - def __init__(self, frame: np.array = None, show=None, show_box=True, *args, **kwargs): + def __init__( + self, frame: np.array = None, show=None, show_box=True, *args, **kwargs + ): super(ConfMapsPlot, self).__init__(*args, **kwargs) self.frame = frame self.show_box = show_box @@ -50,16 +57,18 @@ def __init__(self, frame: np.array = None, show=None, show_box=True, *args, **kw self.rect = QtCore.QRectF(0, 0, self.frame.shape[1], self.frame.shape[0]) if self.show_box: - QtWidgets.QGraphicsRectItem(self.rect, parent=self).setPen(QtGui.QPen("yellow")) + QtWidgets.QGraphicsRectItem(self.rect, parent=self).setPen( + QtGui.QPen("yellow") + ) for channel in range(self.frame.shape[2]): if show is None or channel in show: color_map = h5_colors[channel % len(h5_colors)] # Add QGraphicsPixmapItem as child object - ConfMapPlot(confmap=self.frame[..., channel], - color=color_map, - parent=self) + ConfMapPlot( + confmap=self.frame[..., channel], color=color_map, parent=self + ) def boundingRect(self) -> QtCore.QRectF: """Method required by Qt. @@ -71,6 +80,7 @@ def paint(self, painter, option, widget=None): """ pass + class ConfMapPlot(QtWidgets.QGraphicsPixmapItem): """QGraphicsPixmapItem object for drawing single channel of confidence map. @@ -85,7 +95,9 @@ class ConfMapPlot(QtWidgets.QGraphicsPixmapItem): In most cases this should only be called by ConfMapsPlot. """ - def __init__(self, confmap: np.array = None, color=[255, 255, 255], *args, **kwargs): + def __init__( + self, confmap: np.array = None, color=[255, 255, 255], *args, **kwargs + ): super(ConfMapPlot, self).__init__(*args, **kwargs) self.color_map = color @@ -108,16 +120,16 @@ def get_conf_image(self) -> QtGui.QImage: frame = self.confmap # Colorize single-channel overlap - if np.ptp(frame) <= 1.: + if np.ptp(frame) <= 1.0: frame_a = (frame * 255).astype(np.uint8) frame_r = (frame * self.color_map[0]).astype(np.uint8) frame_g = (frame * self.color_map[1]).astype(np.uint8) frame_b = (frame * self.color_map[2]).astype(np.uint8) else: frame_a = (frame).astype(np.uint8) - frame_r = (frame * (self.color_map[0]/255.)).astype(np.uint8) - frame_g = (frame * (self.color_map[1]/255.)).astype(np.uint8) - frame_b = (frame * (self.color_map[2]/255.)).astype(np.uint8) + frame_r = (frame * (self.color_map[0] / 255.0)).astype(np.uint8) + frame_g = (frame * (self.color_map[1] / 255.0)).astype(np.uint8) + frame_b = (frame * (self.color_map[2] / 255.0)).astype(np.uint8) frame_composite = np.dstack((frame_r, frame_g, frame_b, frame_a)) @@ -126,20 +138,29 @@ def get_conf_image(self) -> QtGui.QImage: return image + def show_confmaps_from_h5(filename, input_format="channels_last", standalone=False): + """Demo function.""" + from sleap.io.video import HDF5Video + video = HDF5Video(filename, "/box", input_format=input_format) - conf_data = HDF5Video(filename, "/confmaps", input_format=input_format, convert_range=False) + conf_data = HDF5Video( + filename, "/confmaps", input_format=input_format, convert_range=False + ) - confmaps_ = [np.clip(conf_data.get_frame(i),0,1) for i in range(conf_data.frames)] + confmaps_ = [np.clip(conf_data.get_frame(i), 0, 1) for i in range(conf_data.frames)] confmaps = np.stack(confmaps_) return demo_confmaps(confmaps=confmaps, video=video, standalone=standalone) + def demo_confmaps(confmaps, video, standalone=False, callback=None): + """Demo function.""" from PySide2 import QtWidgets from sleap.gui.video import QtVideoPlayer - if standalone: app = QtWidgets.QApplication([]) + if standalone: + app = QtWidgets.QApplication([]) win = QtVideoPlayer(video=video) win.setWindowTitle("confmaps") @@ -147,21 +168,21 @@ def demo_confmaps(confmaps, video, standalone=False, callback=None): def plot_confmaps(parent, item_idx): if parent.frame_idx < confmaps.shape[0]: - frame_conf_map = ConfMapsPlot(confmaps[parent.frame_idx,...]) + frame_conf_map = ConfMapsPlot(confmaps[parent.frame_idx, ...]) win.view.scene.addItem(frame_conf_map) win.changedPlot.connect(plot_confmaps) - if callback: win.changedPlot.connect(callback) + if callback: + win.changedPlot.connect(callback) win.plot() - if standalone: app.exec_() + if standalone: + app.exec_() return win + if __name__ == "__main__": data_path = "tests/data/hdf5_format_v1/training.scale=0.50,sigma=10.h5" show_confmaps_from_h5(data_path, input_format="channels_first", standalone=True) - -# data_path = "/Users/tabris/Documents/predictions.h5" -# show_confmaps_from_h5(data_path, input_format="channels_last", standalone=True) \ No newline at end of file diff --git a/sleap/gui/overlays/instance.py b/sleap/gui/overlays/instance.py index c99322a3f..621f0e229 100644 --- a/sleap/gui/overlays/instance.py +++ b/sleap/gui/overlays/instance.py @@ -1,21 +1,31 @@ +""" +Module with overlay for showing instances. +""" import attr -from PySide2 import QtWidgets - from sleap.gui.video import QtVideoPlayer from sleap.io.dataset import Labels -from sleap.gui.overlays.tracks import TrackColorManager + @attr.s(auto_attribs=True) class InstanceOverlay: + """Class for adding instances as overlays on video frames. + + Attributes: + labels: The :class:`Labels` dataset from which to get overlay data. + player: The video player in which to show overlay. + color_predicted: Whether to show predicted instances in color ( + rather than all in gray/yellow). + """ - labels: Labels=None - player: QtVideoPlayer=None - color_manager: TrackColorManager=TrackColorManager(labels) - color_predicted: bool=False + labels: Labels = None + player: QtVideoPlayer = None + color_predicted: bool = False def add_to_scene(self, video, frame_idx): - if self.labels is None: return + """Adds overlay for frame to player scene.""" + if self.labels is None: + return lf = self.labels.find(video, frame_idx, return_new=True)[0] @@ -29,9 +39,11 @@ def add_to_scene(self, video, frame_idx): pseudo_track = len(self.labels.tracks) + count_no_track count_no_track += 1 - is_predicted = hasattr(instance,"score") + is_predicted = hasattr(instance, "score") - self.player.addInstance(instance=instance, - color=self.color_manager.get_color(pseudo_track), - predicted=is_predicted, - color_predicted=self.color_predicted) \ No newline at end of file + self.player.addInstance( + instance=instance, + color=self.player.color_manager.get_color(pseudo_track), + predicted=is_predicted, + color_predicted=self.color_predicted, + ) diff --git a/sleap/gui/overlays/pafs.py b/sleap/gui/overlays/pafs.py index a145483e6..c42234410 100644 --- a/sleap/gui/overlays/pafs.py +++ b/sleap/gui/overlays/pafs.py @@ -9,11 +9,14 @@ from sleap.gui.overlays.base import DataOverlay, h5_colors -class PafOverlay(DataOverlay): +class PafOverlay(DataOverlay): @classmethod def from_h5(cls, filename, input_format="channels_last", **kwargs): - return DataOverlay.from_h5(filename, "/pafs", input_format, overlay_class=MultiQuiverPlot, **kwargs) + return DataOverlay.from_h5( + filename, "/pafs", input_format, overlay_class=MultiQuiverPlot, **kwargs + ) + class MultiQuiverPlot(QtWidgets.QGraphicsObject): """QtWidgets.QGraphicsObject to display multiple quiver plots in a QtWidgets.QGraphicsView. @@ -33,12 +36,15 @@ class MultiQuiverPlot(QtWidgets.QGraphicsObject): When initialized, creates one child QuiverPlot item for each channel. """ - def __init__(self, - frame: np.array = None, - show: list = None, - decimation: int = 5, - scale: float = 1.0, - *args, **kwargs): + def __init__( + self, + frame: np.array = None, + show: list = None, + decimation: int = 2, + scale: float = 1.0, + *args, + **kwargs, + ): super(MultiQuiverPlot, self).__init__(*args, **kwargs) self.frame = frame self.affinity_field = [] @@ -47,23 +53,23 @@ def __init__(self, # if data range is outside [-1, 1], assume it's [-255, 255] and scale if np.ptp(self.frame) > 4: - self.frame = self.frame.astype(np.float64)/255 + self.frame = self.frame.astype(np.float64) / 255 if show is None: - self.show_list = range(self.frame.shape[2]//2) + self.show_list = range(self.frame.shape[2] // 2) else: self.show_list = show for channel in self.show_list: - if channel < self.frame.shape[-1]//2: + if channel < self.frame.shape[-1] // 2: color_map = h5_colors[channel % len(h5_colors)] aff_field_item = QuiverPlot( - field_x=self.frame[..., channel*2], - field_y=self.frame[..., channel*2+1], + field_x=self.frame[..., channel * 2], + field_y=self.frame[..., channel * 2 + 1], color=color_map, decimation=self.decimation, scale=self.scale, - parent=self - ) + parent=self, + ) self.affinity_field.append(aff_field_item) def boundingRect(self) -> QtCore.QRectF: @@ -76,6 +82,7 @@ def paint(self, painter, option, widget=None): """ pass + class QuiverPlot(QtWidgets.QGraphicsObject): """QtWidgets.QGraphicsObject for drawing single quiver plot. @@ -89,20 +96,23 @@ class QuiverPlot(QtWidgets.QGraphicsObject): None. """ - def __init__(self, - field_x: np.array = None, - field_y: np.array = None, - color=[255, 255, 255], - decimation=1, - scale=1, - *args, **kwargs): + def __init__( + self, + field_x: np.array = None, + field_y: np.array = None, + color=[255, 255, 255], + decimation=1, + scale=1, + *args, + **kwargs, + ): super(QuiverPlot, self).__init__(*args, **kwargs) self.field_x, self.field_y = None, None self.color = color self.decimation = decimation self.scale = scale - pen_width = min(4, max(.1, math.log(self.decimation, 20))) + pen_width = min(4, max(0.1, math.log(self.decimation, 20))) self.pen = QtGui.QPen(QtGui.QColor(*self.color), pen_width) self.points = [] self.rect = QtCore.QRectF() @@ -111,7 +121,7 @@ def __init__(self, self.field_x, self.field_y = field_x, field_y h, w = self.field_x.shape - h, w = int(h/self.scale), int(w/self.scale) + h, w = int(h / self.scale), int(w / self.scale) self.rect = QtCore.QRectF(0, 0, w, h) @@ -121,63 +131,66 @@ def _add_arrows(self, min_length=0.01): points = [] if self.field_x is not None and self.field_y is not None: - raw_delta_yx = np.stack((self.field_y,self.field_x),axis=-1) + raw_delta_yx = np.stack((self.field_y, self.field_x), axis=-1) - dim_0 = self.field_x.shape[0]//self.decimation*self.decimation - dim_1 = self.field_x.shape[1]//self.decimation*self.decimation + dim_0 = self.field_x.shape[0] // self.decimation * self.decimation + dim_1 = self.field_x.shape[1] // self.decimation * self.decimation - grid = np.mgrid[0:dim_0:self.decimation, 0:dim_1:self.decimation] - loc_yx = np.moveaxis(grid,0,-1) + grid = np.mgrid[0 : dim_0 : self.decimation, 0 : dim_1 : self.decimation] + loc_yx = np.moveaxis(grid, 0, -1) # Adjust by scaling factor - loc_yx = loc_yx * (1/self.scale) + loc_yx = loc_yx * (1 / self.scale) if self.decimation > 1: delta_yx = self._decimate(raw_delta_yx, self.decimation) # Shift locations to midpoint of decimation square - loc_yx += self.decimation//2 + loc_yx += self.decimation // 2 else: delta_yx = raw_delta_yx # Split into x,y matrices - loc_y, loc_x = loc_yx[...,0], loc_yx[...,1] - delta_y, delta_x = delta_yx[...,0], delta_yx[...,1] + loc_y, loc_x = loc_yx[..., 0], loc_yx[..., 1] + delta_y, delta_x = delta_yx[..., 0], delta_yx[..., 1] # Determine vector endpoint - x2 = delta_x*self.decimation + loc_x - y2 = delta_y*self.decimation + loc_y - line_length = (delta_x**2 + delta_y**2)**.5 + x2 = delta_x * self.decimation + loc_x + y2 = delta_y * self.decimation + loc_y + line_length = (delta_x ** 2 + delta_y ** 2) ** 0.5 # Determine points for arrow arrow_head_size = line_length / 4 - u_dx = np.divide(delta_x, line_length, out=np.zeros_like(delta_x), where=line_length!=0) - u_dy = np.divide(delta_y, line_length, out=np.zeros_like(delta_y), where=line_length!=0) - p1_x = x2 - u_dx*arrow_head_size - u_dy*arrow_head_size - p1_y = y2 - u_dy*arrow_head_size + u_dx*arrow_head_size + u_dx = np.divide( + delta_x, line_length, out=np.zeros_like(delta_x), where=line_length != 0 + ) + u_dy = np.divide( + delta_y, line_length, out=np.zeros_like(delta_y), where=line_length != 0 + ) + p1_x = x2 - u_dx * arrow_head_size - u_dy * arrow_head_size + p1_y = y2 - u_dy * arrow_head_size + u_dx * arrow_head_size - p2_x = x2 - u_dx*arrow_head_size + u_dy*arrow_head_size - p2_y = y2 - u_dy*arrow_head_size - u_dx*arrow_head_size + p2_x = x2 - u_dx * arrow_head_size + u_dy * arrow_head_size + p2_y = y2 - u_dy * arrow_head_size - u_dx * arrow_head_size # Build list of QPointF objects for faster drawing y_x_pairs = itertools.product( - range(delta_yx.shape[0]), - range(delta_yx.shape[1]) - ) + range(delta_yx.shape[0]), range(delta_yx.shape[1]) + ) for y, x in y_x_pairs: - x1, y1 = loc_x[y,x], loc_y[y,x] + x1, y1 = loc_x[y, x], loc_y[y, x] - if line_length[y,x] > min_length: + if line_length[y, x] > min_length: points.append((x1, y1)) - points.append((x2[y,x],y2[y,x])) - points.append((p1_x[y,x],p1_y[y,x])) - points.append((x2[y,x],y2[y,x])) - points.append((p2_x[y,x],p2_y[y,x])) - points.append((x2[y,x],y2[y,x])) - self.points = list(itertools.starmap(QtCore.QPointF,points)) - - def _decimate(self, image:np.array, box:int): + points.append((x2[y, x], y2[y, x])) + points.append((p1_x[y, x], p1_y[y, x])) + points.append((x2[y, x], y2[y, x])) + points.append((p2_x[y, x], p2_y[y, x])) + points.append((x2[y, x], y2[y, x])) + self.points = list(itertools.starmap(QtCore.QPointF, points)) + + def _decimate(self, image: np.array, box: int): height = width = box # Source: https://stackoverflow.com/questions/48482317/slice-an-image-into-tiles-using-numpy _nrows, _ncols, depth = image.shape @@ -188,21 +201,21 @@ def _decimate(self, image:np.array, box:int): ncols, _n = divmod(_ncols, width) if _m != 0 or _n != 0: # if we can't tile whole image, forget about bottom/right edges - image = image[:(nrows+1)*box,:(ncols+1)*box] + image = image[: (nrows + 1) * box, : (ncols + 1) * box] - tiles = np.lib.stride_tricks.as_strided( + tiles = np.lib.stride_tricks.as_strided( np.ravel(image), shape=(nrows, ncols, height, width, depth), strides=(height * _strides[0], width * _strides[1], *_strides), - writeable=False + writeable=False, ) # Since strides accesses the ndarray by memory, we need to swap axes if # the array is stored column-major (Fortran), which it is from h5py. if _strides[0] < _strides[1]: - tiles = np.swapaxes(tiles,0,1) + tiles = np.swapaxes(tiles, 0, 1) - return np.mean(tiles, axis=(2,3)) + return np.mean(tiles, axis=(2, 3)) def boundingRect(self) -> QtCore.QRectF: """Method called by Qt in order to determine whether object is in visible frame.""" @@ -215,19 +228,24 @@ def paint(self, painter, option, widget=None): painter.drawLines(self.points) pass + def show_pafs_from_h5(filename, input_format="channels_last", standalone=False): video = HDF5Video(filename, "/box", input_format=input_format) - paf_data = HDF5Video(filename, "/pafs", input_format=input_format, convert_range=False) + paf_data = HDF5Video( + filename, "/pafs", input_format=input_format, convert_range=False + ) pafs_ = [paf_data.get_frame(i) for i in range(paf_data.frames)] pafs = np.stack(pafs_) return demo_pafs(pafs, video, standalone=standalone) + def demo_pafs(pafs, video, decimation=4, standalone=False): from sleap.gui.video import QtVideoPlayer - if standalone: app = QtWidgets.QApplication([]) + if standalone: + app = QtWidgets.QApplication([]) win = QtVideoPlayer(video=video) win.setWindowTitle("pafs") @@ -246,47 +264,54 @@ def plot_fields(parent, i): if parent.frame_idx < pafs.shape[0]: frame_pafs = pafs[parent.frame_idx, ...] decimation = decimation_size_bar.value() - aff_fields_item = MultiQuiverPlot(frame_pafs, show=None, decimation=decimation) + aff_fields_item = MultiQuiverPlot( + frame_pafs, show=None, decimation=decimation + ) win.view.scene.addItem(aff_fields_item) win.changedPlot.connect(plot_fields) win.plot() - if standalone: app.exec_() + if standalone: + app.exec_() return win + if __name__ == "__main__": from video import * - #data_path = "training.scale=1.00,sigma=5.h5" + # data_path = "training.scale=1.00,sigma=5.h5" data_path = "tests/data/hdf5_format_v1/training.scale=0.50,sigma=10.h5" - input_format="channels_first" + input_format = "channels_first" data_path = "/Volumes/fileset-mmurthy/nat/nyu-mouse/predict.h5" input_format = "channels_last" show_pafs_from_h5(data_path, input_format=input_format, standalone=True) + def foo(): vid = HDF5Video(data_path, "/box", input_format=input_format) - overlay_data = HDF5Video(data_path, "/pafs", input_format=input_format, convert_range=False) - print(f"{overlay_data.frames}, {overlay_data.height}, {overlay_data.width}, {overlay_data.channels}") + overlay_data = HDF5Video( + data_path, "/pafs", input_format=input_format, convert_range=False + ) + print( + f"{overlay_data.frames}, {overlay_data.height}, {overlay_data.width}, {overlay_data.channels}" + ) app = QtWidgets.QApplication([]) window = QtVideoPlayer(video=vid) - field_count = overlay_data.get_frame(1).shape[-1]//2 - 1 + field_count = overlay_data.get_frame(1).shape[-1] // 2 - 1 # show the first, middle, and last fields - show_fields = [0, field_count//2, field_count] + show_fields = [0, field_count // 2, field_count] field_check_groupbox = MultiCheckWidget( - count=field_count, - selected=show_fields, - title="Affinity Field Channel" - ) + count=field_count, selected=show_fields, title="Affinity Field Channel" + ) field_check_groupbox.selectionChanged.connect(window.plot) window.layout.addWidget(field_check_groupbox) @@ -301,7 +326,7 @@ def foo(): decimation_size_bar.setEnabled(True) window.layout.addWidget(decimation_size_bar) - def plot_fields(parent,i): + def plot_fields(parent, i): # build list of checked boxes to determine which affinity fields to show selected = field_check_groupbox.getSelected() # get decimation size from slider @@ -317,4 +342,4 @@ def plot_fields(parent,i): window.show() window.plot() - app.exec_() \ No newline at end of file + app.exec_() diff --git a/sleap/gui/overlays/tracks.py b/sleap/gui/overlays/tracks.py index e968d3783..d9f9b00a5 100644 --- a/sleap/gui/overlays/tracks.py +++ b/sleap/gui/overlays/tracks.py @@ -1,5 +1,8 @@ -from sleap.skeleton import Skeleton, Node -from sleap.instance import Instance, PredictedInstance, Point, LabeledFrame, Track +""" +Module that handles track-related overlays (including track color). +""" + +from sleap.instance import Track from sleap.io.dataset import Labels from sleap.io.video import Video @@ -7,17 +10,22 @@ import itertools from typing import Union -from PySide2 import QtCore, QtWidgets, QtGui +from PySide2 import QtCore, QtGui + -class TrackColorManager: - """Class to determine color to use for track. The color depends on the order of - the tracks in `Labels` object, so we need to initialize with `Labels`. +class TrackColorManager(object): + """Class to determine color to use for track. + + The color depends on the order of the tracks in `Labels` object, + so we need to initialize with `Labels`. Args: - labels: `Labels` object which contains the tracks for which we want colors + labels: The :class:`Labels` dataset which contains the tracks for + which we want colors. + palette: String with the color palette name to use. """ - def __init__(self, labels: Labels=None, palette="standard"): + def __init__(self, labels: Labels = None, palette: str = "standard"): self.labels = labels # alphabet @@ -28,86 +36,84 @@ def __init__(self, labels: Labels=None, palette="standard"): # http://colorbrewer2.org/#type=qualitative&scheme=Paired&n=12 self._palettes = { - "standard" : [ - [0, 114, 189], - [217, 83, 25], - [237, 177, 32], - [126, 47, 142], - [119, 172, 48], - [77, 190, 238], - [162, 20, 47], - ], - "five+" : [ - [228,26,28], - [55,126,184], - [77,175,74], - [152,78,163], - [255,127,0], - ], - "solarized" : [ - [181, 137, 0], - [203, 75, 22], - [220, 50, 47], - [211, 54, 130], - [108, 113, 196], - [38, 139, 210], - [42, 161, 152], - [133, 153, 0], - ], - "alphabet" : [ - [240,163,255], - [0,117,220], - [153,63,0], - [76,0,92], - [25,25,25], - [0,92,49], - [43,206,72], - [255,204,153], - [128,128,128], - [148,255,181], - [143,124,0], - [157,204,0], - [194,0,136], - [0,51,128], - [255,164,5], - [255,168,187], - [66,102,0], - [255,0,16], - [94,241,242], - [0,153,143], - [224,255,102], - [116,10,255], - [153,0,0], - [255,255,128], - [255,255,0], - [255,80,5], - ], - "twelve" : [ - [31,120,180], - [51,160,44], - [227,26,28], - [255,127,0], - [106,61,154], - [177,89,40], - [166,206,227], - [178,223,138], - [251,154,153], - [253,191,111], - [202,178,214], - [255,255,153], - ] + "standard": [ + [0, 114, 189], + [217, 83, 25], + [237, 177, 32], + [126, 47, 142], + [119, 172, 48], + [77, 190, 238], + [162, 20, 47], + ], + "five+": [ + [228, 26, 28], + [55, 126, 184], + [77, 175, 74], + [152, 78, 163], + [255, 127, 0], + ], + "solarized": [ + [181, 137, 0], + [203, 75, 22], + [220, 50, 47], + [211, 54, 130], + [108, 113, 196], + [38, 139, 210], + [42, 161, 152], + [133, 153, 0], + ], + "alphabet": [ + [240, 163, 255], + [0, 117, 220], + [153, 63, 0], + [76, 0, 92], + [25, 25, 25], + [0, 92, 49], + [43, 206, 72], + [255, 204, 153], + [128, 128, 128], + [148, 255, 181], + [143, 124, 0], + [157, 204, 0], + [194, 0, 136], + [0, 51, 128], + [255, 164, 5], + [255, 168, 187], + [66, 102, 0], + [255, 0, 16], + [94, 241, 242], + [0, 153, 143], + [224, 255, 102], + [116, 10, 255], + [153, 0, 0], + [255, 255, 128], + [255, 255, 0], + [255, 80, 5], + ], + "twelve": [ + [31, 120, 180], + [51, 160, 44], + [227, 26, 28], + [255, 127, 0], + [106, 61, 154], + [177, 89, 40], + [166, 206, 227], + [178, 223, 138], + [251, 154, 153], + [253, 191, 111], + [202, 178, 214], + [255, 255, 153], + ], } self.mode = "cycle" - self._modes = dict( - cycle=lambda i, c: i%c, - clip=lambda i, c: min(i,c-1), - ) + self._modes = dict(cycle=lambda i, c: i % c, clip=lambda i, c: min(i, c - 1)) self.set_palette(palette) @property def labels(self): + """Gets or sets labels dataset for which we are coloring tracks.""" return self._labels @labels.setter @@ -116,9 +122,11 @@ def labels(self, val): @property def palette_names(self): + """Gets list of palette names.""" return self._palettes.keys() def set_palette(self, palette): + """Sets palette (by name).""" if isinstance(palette, str): self.mode = "clip" if palette.endswith("+") else "cycle" @@ -137,39 +145,44 @@ def get_color(self, track: Union[Track, int]): Returns: (r, g, b)-tuple """ - track_idx = self.labels.tracks.index(track) if isinstance(track, Track) else track + track_idx = ( + self.labels.tracks.index(track) if isinstance(track, Track) else track + ) color_idx = self._modes[self.mode](track_idx, len(self._color_map)) color = self._color_map[color_idx] return color + @attr.s(auto_attribs=True) class TrackTrailOverlay: - """Class to show track trails. You initialize this object with both its data source - and its visual output scene, and it handles both extracting the relevant data for a - given frame and plotting it in the output. + """Class to show track trails as overlay on video frame. - Args: - labels: `Labels` object from which to get data - scene: `QGraphicsScene` in which to plot trails - trail_length (optional): maximum number of frames to include in trail + Initialize this object with both its data source and its visual output + scene, and it handles both extracting the relevant data for a given + frame and plotting it in the output. + + Attributes: + labels: The :class:`Labels` dataset from which to get overlay data. + player: The video player in which to show overlay. + trail_length: The maximum number of frames to include in trail. Usage: - After class is instatiated, call add_trails_to_scene(frame_idx) + After class is instantiated, call :meth:`add_to_scene(frame_idx)` to plot the trails in scene. """ - labels: Labels=None - scene: QtWidgets.QGraphicsScene=None - color_manager: TrackColorManager=TrackColorManager(labels) - trail_length: int=4 - show: bool=False - + labels: Labels = None + player: "QtVideoPlayer" = None + trail_length: int = 4 + show: bool = False + def get_track_trails(self, frame_selection, track: Track): """Get data needed to draw track trail. Args: - frame_selection: an interable with the `LabeledFrame`s to include in trail - track: the `Track` for which to get trail + frame_selection: an interable with the :class:`LabeledFrame` + objects to include in trail. + track: the :class:`Track` for which to get trail Returns: list of lists of (x, y) tuples @@ -201,17 +214,21 @@ def get_track_trails(self, frame_selection, track: Track): return all_trails def get_frame_selection(self, video: Video, frame_idx: int): - """Return list of `LabeledFrame`s to include in trail for specified frame.""" + """ + Return `LabeledFrame` objects to include in trail for specified frame. + """ - frame_selection = self.labels.find(video, range(0, frame_idx+1)) + frame_selection = self.labels.find(video, range(0, frame_idx + 1)) frame_selection.sort(key=lambda x: x.frame_idx) - return frame_selection[-self.trail_length:] + return frame_selection[-self.trail_length :] def get_tracks_in_frame(self, video: Video, frame_idx: int): """Return list of tracks that have instance in specified frame.""" - - tracks_in_frame = [inst.track for lf in self.labels.find(video, frame_idx) for inst in lf] + + tracks_in_frame = [ + inst.track for lf in self.labels.find(video, frame_idx) for inst in lf + ] return tracks_in_frame def add_to_scene(self, video: Video, frame_idx: int): @@ -222,7 +239,8 @@ def add_to_scene(self, video: Video, frame_idx: int): frame_idx: index of the frame to which the trail is attached """ - if not self.show: return + if not self.show: + return frame_selection = self.get_frame_selection(video, frame_idx) tracks_in_frame = self.get_tracks_in_frame(video, frame_idx) @@ -231,24 +249,83 @@ def add_to_scene(self, video: Video, frame_idx: int): trails = self.get_track_trails(frame_selection, track) - color = QtGui.QColor(*self.color_manager.get_color(track)) + color = QtGui.QColor(*self.player.color_manager.get_color(track)) pen = QtGui.QPen() pen.setCosmetic(True) for trail in trails: - half = len(trail)//2 + half = len(trail) // 2 color.setAlphaF(1) pen.setColor(color) polygon = self.map_to_qt_polygon(trail[:half]) - self.scene.addPolygon(polygon, pen) + self.player.scene.addPolygon(polygon, pen) - color.setAlphaF(.5) + color.setAlphaF(0.5) pen.setColor(color) polygon = self.map_to_qt_polygon(trail[half:]) - self.scene.addPolygon(polygon, pen) + self.player.scene.addPolygon(polygon, pen) @staticmethod def map_to_qt_polygon(point_list): """Converts a list of (x, y)-tuples to a `QPolygonF`.""" - return QtGui.QPolygonF(list(itertools.starmap(QtCore.QPointF, point_list))) \ No newline at end of file + return QtGui.QPolygonF(list(itertools.starmap(QtCore.QPointF, point_list))) + + +@attr.s(auto_attribs=True) +class TrackListOverlay: + """ + Class to show track number and names in overlay. + """ + + labels: Labels = None + player: "QtVideoPlayer" = None + text_box = None + + def add_to_scene(self, video: Video, frame_idx: int): + """Adds track list as overlay on video.""" + from sleap.gui.video import QtTextWithBackground + + html = "" + num_to_show = min(9, len(self.labels.tracks)) + + for i, track in enumerate(self.labels.tracks[:num_to_show]): + idx = i + 1 + + if html: + html += "
" + color = self.player.color_manager.get_color(track) + html_color = f"#{color[0]:02X}{color[1]:02X}{color[2]:02X}" + track_text = f"{track.name}" + if str(idx) != track.name: + track_text += f" ({idx})" + html += f"{track_text}" + + text_box = QtTextWithBackground() + text_box.setDefaultTextColor(QtGui.QColor("white")) + text_box.setHtml(html) + text_box.setOpacity(0.7) + + self.text_box = text_box + self.visible = False + + self.player.scene.addItem(self.text_box) + + @property + def visible(self): + """Gets or set whether overlay is visible.""" + if self.text_box is None: + return False + return self.text_box.isVisible() + + @visible.setter + def visible(self, val): + if self.text_box is None: + return + if val: + pos = self.player.view.mapToScene(10, 10) + if pos.x() > 0: + self.text_box.setPos(pos) + else: + self.text_box.setPos(10, 10) + self.text_box.setVisible(val) diff --git a/sleap/gui/shortcuts.py b/sleap/gui/shortcuts.py new file mode 100644 index 000000000..c4953de69 --- /dev/null +++ b/sleap/gui/shortcuts.py @@ -0,0 +1,210 @@ +""" +Gui for keyboard shortcuts. +""" +from PySide2 import QtWidgets +from PySide2.QtGui import QKeySequence + +import yaml + +from typing import Dict, List, Union +from pkg_resources import Requirement, resource_filename + + +class ShortcutDialog(QtWidgets.QDialog): + """ + Dialog window for reviewing and modifying the keyboard shortcuts. + """ + + _column_len = 13 + + def __init__(self, *args, **kwargs): + super(ShortcutDialog, self).__init__(*args, **kwargs) + + self.setWindowTitle("Keyboard Shortcuts") + self.load_shortcuts() + self.make_form() + + def accept(self): + """Triggered when form is accepted; saves the shortcuts.""" + for action, widget in self.key_widgets.items(): + self.shortcuts[action] = widget.keySequence().toString() + self.shortcuts.save() + + super(ShortcutDialog, self).accept() + + def load_shortcuts(self): + """Loads shortcuts object.""" + self.shortcuts = Shortcuts() + + def make_form(self): + """Creates the form with fields for all shortcuts.""" + self.key_widgets = dict() # dict to store QKeySequenceEdit widgets + + layout = QtWidgets.QVBoxLayout() + layout.addWidget(self.make_shortcuts_widget()) + layout.addWidget( + QtWidgets.QLabel( + "Any changes to keyboard shortcuts will not take effect " + "until you quit and re-open the application." + ) + ) + layout.addWidget(self.make_buttons_widget()) + self.setLayout(layout) + + def make_buttons_widget(self) -> QtWidgets.QDialogButtonBox: + """Makes the form buttons.""" + buttons = QtWidgets.QDialogButtonBox( + QtWidgets.QDialogButtonBox.Ok | QtWidgets.QDialogButtonBox.Cancel + ) + buttons.accepted.connect(self.accept) + buttons.rejected.connect(self.reject) + return buttons + + def make_shortcuts_widget(self) -> QtWidgets.QWidget: + """Makes the widget will fields for all shortcuts.""" + shortcuts = self.shortcuts + + widget = QtWidgets.QWidget() + layout = QtWidgets.QHBoxLayout() + + # show shortcuts in columns + for a in range(0, len(shortcuts), self._column_len): + b = min(len(shortcuts), a + self._column_len) + column_widget = self.make_column_widget(shortcuts[a:b]) + layout.addWidget(column_widget) + widget.setLayout(layout) + return widget + + def make_column_widget(self, shortcuts: List) -> QtWidgets.QWidget: + """Makes a single column of shortcut fields. + + Args: + shortcuts: The list of shortcuts to include in this column. + + Returns: + The widget. + """ + column_widget = QtWidgets.QWidget() + column_layout = QtWidgets.QFormLayout() + for action in shortcuts: + item = QtWidgets.QKeySequenceEdit(shortcuts[action]) + column_layout.addRow(action.title(), item) + self.key_widgets[action] = item + column_widget.setLayout(column_layout) + return column_widget + + +class Shortcuts(object): + """ + Class for accessing keyboard shortcuts. + + Shortcuts are saved in `sleap/config/shortcuts.yaml` + + When instantiated, this reads in the shortcuts from the file. + """ + + _shortcuts = None + _names = ( + "new", + "open", + "save", + "save as", + "close", + "add videos", + "next video", + "prev video", + "goto frame", + "mark frame", + "goto marked", + "add instance", + "delete instance", + "delete track", + "transpose", + "select next", + "clear selection", + "goto next labeled", + "goto prev labeled", + "goto next user", + "goto next suggestion", + "goto prev suggestion", + "goto next track spawn", + "show labels", + "show edges", + "show trails", + "color predicted", + "fit", + "learning", + "export clip", + "delete clip", + "delete area", + ) + + def __init__(self): + shortcut_yaml = resource_filename( + Requirement.parse("sleap"), "sleap/config/shortcuts.yaml" + ) + with open(shortcut_yaml, "r") as f: + shortcuts = yaml.load(f, Loader=yaml.SafeLoader) + + for action in shortcuts: + key_string = shortcuts.get(action, None) + key_string = "" if key_string is None else key_string + + try: + shortcuts[action] = eval(key_string) + except: + shortcuts[action] = QKeySequence.fromString(key_string) + + self._shortcuts = shortcuts + + def save(self): + """Saves all shortcuts to shortcut file.""" + shortcut_yaml = resource_filename( + Requirement.parse("sleap"), "sleap/config/shortcuts.yaml" + ) + with open(shortcut_yaml, "w") as f: + yaml.dump(self._shortcuts, f) + + def __getitem__(self, idx: Union[slice, int, str]) -> Union[str, Dict[str, str]]: + """ + Returns shortcut value, accessed by range, index, or key. + + Args: + idx: Index (range, int, or str) of shortcut to access. + + Returns: + If idx is int or string, then return value is the shortcut string. + If idx is range, then return value is dictionary in which keys + are shortcut name and value are shortcut strings. + """ + if isinstance(idx, slice): + # dict with names and values + return {self._names[i]: self[i] for i in range(*idx.indices(len(self)))} + elif isinstance(idx, int): + # value + idx = self._names[idx] + return self[idx] + else: + # value + if idx in self._shortcuts: + return self._shortcuts[idx] + return "" + + def __setitem__(self, idx: Union[str, int], val: str): + """Sets shortcut by index.""" + if type(idx) == int: + idx = self._names[idx] + self[idx] = val + else: + self._shortcuts[idx] = val + + def __len__(self): + """Returns number of shortcuts.""" + return len(self._names) + + +if __name__ == "__main__": + app = QtWidgets.QApplication() + win = ShortcutDialog() + win.show() + app.exec_() diff --git a/sleap/gui/slider.py b/sleap/gui/slider.py index 24333215d..7d2720b06 100644 --- a/sleap/gui/slider.py +++ b/sleap/gui/slider.py @@ -2,18 +2,77 @@ Drop-in replacement for QSlider with additional features. """ -from PySide2.QtWidgets import QApplication, QWidget, QLayout, QAbstractSlider -from PySide2.QtWidgets import QGraphicsView, QGraphicsScene, QGraphicsItem -from PySide2.QtWidgets import QSizePolicy, QLabel, QGraphicsRectItem -from PySide2.QtGui import QPainter, QPen, QBrush, QColor, QKeyEvent -from PySide2.QtCore import Qt, Signal, QRect, QRectF +from PySide2 import QtCore, QtWidgets +from PySide2.QtGui import QPen, QBrush, QColor, QKeyEvent, QPolygonF, QPainterPath from sleap.gui.overlays.tracks import TrackColorManager -from operator import itemgetter -from itertools import groupby +import attr +import itertools +import numpy as np +from typing import Dict, Iterable, List, Optional, Union -class VideoSlider(QGraphicsView): + +@attr.s(auto_attribs=True, cmp=False) +class SliderMark: + """ + Class to hold data for an individual mark on the slider. + + Attributes: + type: Type of the mark, options are: + * "simple" (single value) + * "filled" (single value) + * "open" (single value) + * "predicted" (single value) + * "track" (range of values) + val: Beginning of mark range + end_val: End of mark range (for "track" marks) + row: The row that the mark goes in; used for tracks. + color: Color of mark, can be string or (r, g, b) tuple. + filled: Whether the mark is shown filled (solid color). + """ + + type: str + val: float + end_val: float = None + row: int = None + track: "Track" = None + _color: Union[tuple, str] = "black" + + @property + def color(self): + """Returns color of mark.""" + colors = dict(simple="black", filled="blue", open="blue", predicted="red") + + if self.type in colors: + return colors[self.type] + else: + return self._color + + @color.setter + def color(self, val): + """Sets color of mark.""" + self._color = val + + @property + def QColor(self): + """Returns color of mark as `QColor`.""" + c = self.color + if type(c) == str: + return QColor(c) + else: + return QColor(*c) + + @property + def filled(self): + """Returns whether mark is filled or open.""" + if self.type == "open": + return False + else: + return True + + +class VideoSlider(QtWidgets.QGraphicsView): """Drop-in replacement for QSlider with additional features. Args: @@ -25,50 +84,79 @@ class VideoSlider(QGraphicsView): this can be either * list of values to mark * list of (track, value)-tuples to mark + color_manager: A :class:`TrackColorManager` which determines the + color to use for "track"-type marks + + Signals: + mousePressed: triggered on Qt event + mouseMoved: triggered on Qt event + mouseReleased: triggered on Qt event + keyPress: triggered on Qt event + keyReleased: triggered on Qt event + valueChanged: triggered when value of slider changes + selectionChanged: triggered when slider range selection changes + heightUpdated: triggered when the height of slider changes """ - mousePressed = Signal(float, float) - mouseMoved = Signal(float, float) - mouseReleased = Signal(float, float) - keyPress = Signal(QKeyEvent) - keyRelease = Signal(QKeyEvent) - valueChanged = Signal(int) - selectionChanged = Signal(int, int) - updatedTracks = Signal() - - def __init__(self, orientation=-1, min=0, max=100, val=0, - marks=None, tracks=0, - color_manager=None, - *args, **kwargs): + mousePressed = QtCore.Signal(float, float) + mouseMoved = QtCore.Signal(float, float) + mouseReleased = QtCore.Signal(float, float) + keyPress = QtCore.Signal(QKeyEvent) + keyRelease = QtCore.Signal(QKeyEvent) + valueChanged = QtCore.Signal(int) + selectionChanged = QtCore.Signal(int, int) + heightUpdated = QtCore.Signal() + + def __init__( + self, + orientation=-1, # for compatibility with QSlider + min=0, + max=100, + val=0, + marks=None, + color_manager: Optional[TrackColorManager] = None, + *args, + **kwargs + ): super(VideoSlider, self).__init__(*args, **kwargs) - self.scene = QGraphicsScene() + self.scene = QtWidgets.QGraphicsScene() self.setScene(self.scene) - self.setAlignment(Qt.AlignLeft | Qt.AlignTop) + self.setAlignment(QtCore.Qt.AlignLeft | QtCore.Qt.AlignTop) - self.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed) - self.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) - self.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff) # ScrollBarAsNeeded + self.setSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Fixed) + self.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff) + self.setVerticalScrollBarPolicy( + QtCore.Qt.ScrollBarAlwaysOff + ) # ScrollBarAsNeeded - self._color_manager = color_manager or TrackColorManager() + self._color_manager = color_manager + self._track_rows = 0 self._track_height = 3 + self._header_height = 0 + self._min_height = 19 + self._header_height - height = 19 - slider_rect = QRect(0, 0, 200, height-3) - handle_width = 6 - handle_rect = QRect(0, 1, handle_width, slider_rect.height()-2) - self.setMinimumHeight(height) - self.setMaximumHeight(height) - - self.slider = self.scene.addRect(slider_rect) - self.slider.setPen(QPen(QColor("black"))) + # Add border rect + outline_rect = QtCore.QRect(0, 0, 200, self._min_height - 3) + self.outlineBox = self.scene.addRect(outline_rect) + self.outlineBox.setPen(QPen(QColor("black"))) + # Add drag handle rect + handle_width = 6 + handle_rect = QtCore.QRect( + 0, self._handleTop(), handle_width, self._handleHeight() + ) + self.setMinimumHeight(self._min_height) + self.setMaximumHeight(self._min_height) self.handle = self.scene.addRect(handle_rect) self.handle.setPen(QPen(QColor(80, 80, 80))) self.handle.setBrush(QColor(128, 128, 128, 128)) - self.select_box = self.scene.addRect(QRect(0, 1, 0, slider_rect.height()-2)) + # Add (hidden) rect to highlight selection + self.select_box = self.scene.addRect( + QtCore.QRect(0, 1, 0, outline_rect.height() - 2) + ) self.select_box.setPen(QPen(QColor(80, 80, 255))) self.select_box.setBrush(QColor(80, 80, 255, 128)) self.select_box.hide() @@ -82,113 +170,188 @@ def __init__(self, orientation=-1, min=0, max=100, val=0, self.setValue(val) self.setMarks(marks) - def setTracksFromLabels(self, labels, video): + pen = QPen(QColor(80, 80, 255), 0.5) + pen.setCosmetic(True) + self.poly = self.scene.addPath(QPainterPath(), pen, self.select_box.brush()) + self.headerSeries = dict() + self.drawHeader() + + def _pointsToPath(self, points: List[QtCore.QPointF]) -> QPainterPath: + """Converts list of `QtCore.QPointF` objects to a `QPainterPath`.""" + path = QPainterPath() + path.addPolygon(QPolygonF(points)) + return path + + def setTracksFromLabels(self, labels: "Labels", video: "Video"): """Set slider marks using track information from `Labels` object. Note that this is the only method coupled to a SLEAP object. Args: - labels: the `labels` with tracks and labeled_frames + labels: the dataset with tracks and labeled frames video: the video for which to show marks + + Returns: + None """ + + if self._color_manager is None: + self._color_manager = TrackColorManager(labels=labels) + lfs = labels.find(video) slider_marks = [] - track_idx = 0 + track_row = 0 # Add marks with track track_occupancy = labels.get_track_occupany(video) for track in labels.tracks: -# track_idx = labels.tracks.index(track) if track in track_occupancy and not track_occupancy[track].is_empty: for occupancy_range in track_occupancy[track].list: - slider_marks.append((track_idx, *occupancy_range)) - track_idx += 1 + slider_marks.append( + SliderMark( + "track", + val=occupancy_range[0], + end_val=occupancy_range[1], + row=track_row, + color=self._color_manager.get_color(track), + ) + ) + track_row += 1 # Add marks without track if None in track_occupancy: for occupancy_range in track_occupancy[None].list: - slider_marks.extend(range(*occupancy_range)) + for val in range(*occupancy_range): + slider_marks.append(SliderMark("simple", val=val)) # list of frame_idx for simple markers for labeled frames labeled_marks = [lf.frame_idx for lf in lfs] user_labeled = [lf.frame_idx for lf in lfs if len(lf.user_instances)] - # "f" for suggestions with instances and "o" for those without - # "f" means "filled", "o" means "open" - # "p" for suggestions with only predicted instances - def mark_type(frame): - if frame in user_labeled: - return "f" - elif frame in labeled_marks: - return "p" + + for frame_idx in labels.get_video_suggestions(video): + if frame_idx in user_labeled: + mark_type = "filled" + elif frame_idx in labeled_marks: + mark_type = "predicted" else: - return "o" - # list of (type, frame) tuples for suggestions - suggestion_marks = [(mark_type(frame_idx), frame_idx) - for frame_idx in labels.get_video_suggestions(video)] - # combine marks for labeled frame and marks for suggested frames - slider_marks.extend(suggestion_marks) - - self.setTracks(track_idx) + mark_type = "open" + slider_marks.append(SliderMark(mark_type, val=frame_idx)) + + self.setTracks(track_row) # total number of tracks to show self.setMarks(slider_marks) - self.updatedTracks.emit() + def setHeaderSeries(self, series: Optional[Dict[int, float]] = None): + """Show header graph with specified series. - def setTracks(self, tracks): + Args: + series: {frame number: series value} dict. + Returns: + None. + """ + self.headerSeries = [] if series is None else series + self._header_height = 30 + self.drawHeader() + self.updateHeight() + + def clearHeader(self): + """Remove header graph from slider.""" + self.headerSeries = [] + self._header_height = 0 + self.updateHeight() + + def setTracks(self, track_rows): """Set the number of tracks to show in slider. Args: - tracks: the number of tracks to show + track_rows: the number of tracks to show """ + self._track_rows = track_rows + self.updateHeight() + + def updateHeight(self): + """Update the height of the slider.""" + + tracks = self._track_rows if tracks == 0: - min_height = max_height = 19 + min_height = self._min_height + max_height = self._min_height else: - min_height = max(19, 8 + (self._track_height * min(tracks, 20))) - max_height = max(19, 8 + (self._track_height * tracks)) + # Start with padding height + extra_height = 8 + self._header_height + min_height = extra_height + max_height = extra_height + + # Add height for tracks + min_height += self._track_height * min(tracks, 20) + max_height += self._track_height * tracks + + # Make sure min/max height is at least 19, even if few tracks + min_height = max(self._min_height, min_height) + max_height = max(self._min_height, max_height) self.setMaximumHeight(max_height) self.setMinimumHeight(min_height) + + # Redraw all marks with new height and y position + marks = self.getMarks() + self.setMarks(marks) + self.resizeEvent() + self.heightUpdated.emit() + + def _toPos(self, val: float, center=False) -> float: + """ + Converts slider value to x position on slider. + + Args: + val: The slider value. + center: Whether to offset by half the width of drag handle, + so that plotted location will light up with center of handle. - def _toPos(self, val, center=False): + Returns: + x position. + """ x = val x -= self._val_min - x /= max(1, self._val_max-self._val_min) + x /= max(1, self._val_max - self._val_min) x *= self._sliderWidth() if center: - x += self.handle.rect().width()/2. + x += self.handle.rect().width() / 2.0 return x - def _toVal(self, x, center=False): + def _toVal(self, x: float, center=False) -> float: + """Converts x position to slider value.""" val = x val /= self._sliderWidth() - val *= max(1, self._val_max-self._val_min) + val *= max(1, self._val_max - self._val_min) val += self._val_min val = round(val) return val - def _sliderWidth(self): - return self.slider.rect().width()-self.handle.rect().width() + def _sliderWidth(self) -> float: + """Returns visual width of slider.""" + return self.outlineBox.rect().width() - self.handle.rect().width() - def value(self): - """Get value of slider.""" + def value(self) -> float: + """Returns value of slider.""" return self._val_main - def setValue(self, val): - """Set value of slider.""" + def setValue(self, val: float) -> float: + """Sets value of slider.""" self._val_main = val x = self._toPos(val) self.handle.setPos(x, 0) - def setMinimum(self, min): - """Set minimum value for slider.""" + def setMinimum(self, min: float) -> float: + """Sets minimum value for slider.""" self._val_min = min - def setMaximum(self, max): - """Set maximum value for slider.""" + def setMaximum(self, max: float) -> float: + """Sets maximum value for slider.""" self._val_max = max - def setEnabled(self, val): + def setEnabled(self, val: float) -> float: """Set whether the slider is enabled.""" self._enabled = val @@ -197,26 +360,31 @@ def enabled(self): return self._enabled def clearSelection(self): - """Clear selection endpoints.""" + """Clears selection endpoints.""" self._selection = [] self.select_box.hide() def startSelection(self, val): - """Add initial selection endpoint. + """Adds initial selection endpoint. + + Called when user starts dragging to select range in slider. Args: val: value of endpoint """ self._selection.append(val) - def endSelection(self, val, update=False): + def endSelection(self, val, update: bool = False): """Add final selection endpoint. + Called during or after the user is dragging to select range. + Args: val: value of endpoint + update: """ # If we want to update endpoint and there's already one, remove it - if update and len(self._selection)%2==0: + if update and len(self._selection) % 2 == 0: self._selection.pop() # Add the selection endpoint self._selection.append(val) @@ -229,165 +397,163 @@ def endSelection(self, val, update=False): self.selectionChanged.emit(*self.getSelection()) def hasSelection(self) -> bool: - """Return True if a clip is selected, False otherwise.""" + """Returns True if a clip is selected, False otherwise.""" a, b = self.getSelection() return a < b def getSelection(self): - """Return start and end value of current selection endpoints.""" + """Returns start and end value of current selection endpoints.""" a, b = 0, 0 - if len(self._selection)%2 == 0 and len(self._selection) > 0: + if len(self._selection) % 2 == 0 and len(self._selection) > 0: a, b = self._selection[-2:] start = min(a, b) end = max(a, b) return start, end - def drawSelection(self, a, b): - """Draw selection box on slider. + def drawSelection(self, a: float, b: float): + """Draws selection box on slider. Args: a: one endpoint value b: other endpoint value + + Returns: + None. """ start = min(a, b) end = max(a, b) start_pos = self._toPos(start, center=True) end_pos = self._toPos(end, center=True) - selection_rect = QRect(start_pos, 1, - end_pos-start_pos, self.slider.rect().height()-2) + selection_rect = QtCore.QRect( + start_pos, 1, end_pos - start_pos, self.outlineBox.rect().height() - 2 + ) self.select_box.setRect(selection_rect) self.select_box.show() - def moveSelectionAnchor(self, x, y): - """Move selection anchor in response to mouse position. + def moveSelectionAnchor(self, x: float, y: float): + """ + Moves selection anchor in response to mouse position. Args: x: x position of mouse y: y position of mouse + + Returns: + None. """ x = max(x, 0) - x = min(x, self.slider.rect().width()) + x = min(x, self.outlineBox.rect().width()) anchor_val = self._toVal(x, center=True) - if len(self._selection)%2 == 0: + if len(self._selection) % 2 == 0: self.startSelection(anchor_val) self.drawSelection(anchor_val, self._selection[-1]) def releaseSelectionAnchor(self, x, y): - """Finish selection in response to mouse release. + """ + Finishes selection in response to mouse release. Args: x: x position of mouse y: y position of mouse + + Returns: + None. """ x = max(x, 0) - x = min(x, self.slider.rect().width()) + x = min(x, self.outlineBox.rect().width()) anchor_val = self._toVal(x) self.endSelection(anchor_val) def clearMarks(self): - """Clear all marked values for slider.""" + """Clears all marked values for slider.""" if hasattr(self, "_mark_items"): for item in self._mark_items.values(): self.scene.removeItem(item) - self._marks = set() # holds mark position - self._mark_items = dict() # holds visual Qt object for plotting mark + self._marks = set() # holds mark position + self._mark_items = dict() # holds visual Qt object for plotting mark - def setMarks(self, marks): - """Set all marked values for the slider. + def setMarks(self, marks: Iterable[Union[SliderMark, int]]): + """Sets all marked values for the slider. Args: marks: iterable with all values to mark + + Returns: + None. """ self.clearMarks() if marks is not None: for mark in marks: + if not isinstance(mark, SliderMark): + mark = SliderMark("simple", mark) self.addMark(mark, update=False) self.updatePos() def getMarks(self): - """Return list of marks. - - Each mark is either val or (track, val)-tuple. - """ + """Returns list of marks.""" return self._marks - def addMark(self, new_mark, update=True): - """Add a marked value to the slider. + def addMark(self, new_mark: SliderMark, update: bool = True): + """Adds a marked value to the slider. Args: new_mark: value to mark + update: Whether to redraw slider with new mark. + + Returns: + None. """ # check if mark is within slider range - if self._mark_val(new_mark) > self._val_max: return - if self._mark_val(new_mark) < self._val_min: return + if new_mark.val > self._val_max: + return + if new_mark.val < self._val_min: + return self._marks.add(new_mark) + v_top_pad = 3 + self._header_height + v_bottom_pad = 3 + width = 0 - filled = True - if type(new_mark) == tuple: - if type(new_mark[0]) == int: - # colored track if mark has format: (track_number, start_frame_idx, end_frame_idx) - track = new_mark[0] - v_offset = 3 + (self._track_height * track) - height = 1 - color = QColor(*self._color_manager.get_color(track)) - else: - # rect (open/filled) if format: ("o", frame_idx) or ("f", frame_idx) - # ("p", frame_idx) when only predicted instances on frame - mark_type = new_mark[0] - v_offset = 3 - height = self.slider.rect().height()-6 - if mark_type == "o": - width = 2 - filled = False - color = QColor("blue") - elif mark_type == "f": - width = 2 - color = QColor("blue") - elif mark_type == "p": - width = 0 - color = QColor("red") + if new_mark.type == "track": + v_offset = v_top_pad + (self._track_height * new_mark.row) + height = 1 else: - # line if mark has format: frame_idx - v_offset = 3 - height = self.slider.rect().height()-6 - color = QColor("black") + v_offset = v_top_pad + height = self.outlineBox.rect().height() - (v_offset + v_bottom_pad) + + width = 2 if new_mark.type in ("open", "filled") else 0 - pen = QPen(color, .5) + color = new_mark.QColor + pen = QPen(color, 0.5) pen.setCosmetic(True) - brush = QBrush(color) if filled else QBrush() + brush = QBrush(color) if new_mark.filled else QBrush() - line = self.scene.addRect(-width//2, v_offset, width, height, - pen, brush) + line = self.scene.addRect(-width // 2, v_offset, width, height, pen, brush) self._mark_items[new_mark] = line - if update: self.updatePos() - - def _mark_val(self, mark): - return mark[1] if type(mark) == tuple else mark + if update: + self.updatePos() def updatePos(self): - """Update the visual position of handle and slider annotations.""" + """Update the visual x position of handle and slider annotations.""" x = self._toPos(self.value()) self.handle.setPos(x, 0) + for mark in self._mark_items.keys(): + width = 0 - if type(mark) == tuple: - in_track = True - v = mark[1] - if type(mark[0]) == int: - width_in_frames = mark[2] - mark[1] - width = max(2, self._toPos(width_in_frames)) - elif mark[0] == "o": - width = 2 - else: - in_track = False - v = mark - x = self._toPos(v, center=True) + if mark.type == "track": + width_in_frames = mark.end_val - mark.val + width = max(2, self._toPos(width_in_frames)) + + elif mark.type in ("open", "filled"): + width = 2 + + x = self._toPos(mark.val, center=True) self._mark_items[mark].setPos(x, 0) rect = self._mark_items[mark].rect() @@ -395,6 +561,46 @@ def updatePos(self): self._mark_items[mark].setRect(rect) + def drawHeader(self): + """Draw the header graph.""" + if len(self.headerSeries) == 0 or self._header_height == 0: + self.poly.setPath(QPainterPath()) + return + + step = max(self.headerSeries.keys()) // int(self._sliderWidth()) + step = max(step, 1) + count = max(self.headerSeries.keys()) // step * step + + sampled = np.full((count), 0.0) + for key, val in self.headerSeries.items(): + if key < count: + sampled[key] = val + sampled = np.max(sampled.reshape(count // step, step), axis=1) + series = {i * step: sampled[i] for i in range(count // step)} + + series_min = np.min(sampled) - 1 + series_max = np.max(sampled) + series_scale = (self._header_height - 5) / (series_max - series_min) + + def toYPos(val): + return self._header_height - ((val - series_min) * series_scale) + + step_chart = False # use steps rather than smooth line + + points = [] + points.append((self._toPos(0, center=True), toYPos(series_min))) + for idx, val in series.items(): + points.append((self._toPos(idx, center=True), toYPos(val))) + if step_chart: + points.append((self._toPos(idx + step, center=True), toYPos(val))) + points.append( + (self._toPos(max(series.keys()) + 1, center=True), toYPos(series_min)) + ) + + # Convert to list of QtCore.QPointF objects + points = list(itertools.starmap(QtCore.QPointF, points)) + self.poly.setPath(self._pointsToPath(points)) + def moveHandle(self, x, y): """Move handle in response to mouse position. @@ -404,20 +610,21 @@ def moveHandle(self, x, y): x: x position of mouse y: y position of mouse """ - x -= self.handle.rect().width()/2. + x -= self.handle.rect().width() / 2.0 x = max(x, 0) - x = min(x, self.slider.rect().width()-self.handle.rect().width()) + x = min(x, self.outlineBox.rect().width() - self.handle.rect().width()) val = self._toVal(x) # snap to nearby mark within handle - mark_vals = [self._mark_val(mark) for mark in self._marks] - handle_left = self._toVal(x - self.handle.rect().width()/2) - handle_right = self._toVal(x + self.handle.rect().width()/2) - marks_in_handle = [mark for mark in mark_vals - if handle_left < mark < handle_right] + mark_vals = [mark.val for mark in self._marks] + handle_left = self._toVal(x - self.handle.rect().width() / 2) + handle_right = self._toVal(x + self.handle.rect().width() / 2) + marks_in_handle = [ + mark for mark in mark_vals if handle_left < mark < handle_right + ] if marks_in_handle: - marks_in_handle.sort(key=lambda m: (abs(m-val), m>val)) + marks_in_handle.sort(key=lambda m: (abs(m - val), m > val)) val = marks_in_handle[0] old = self.value() @@ -432,25 +639,53 @@ def resizeEvent(self, event=None): Args: event """ - height = self.size().height() - slider_rect = self.slider.rect() + outline_rect = self.outlineBox.rect() handle_rect = self.handle.rect() select_box_rect = self.select_box.rect() - slider_rect.setHeight(height-3) - if event is not None: slider_rect.setWidth(event.size().width()-1) - handle_rect.setHeight(slider_rect.height()-2) - select_box_rect.setHeight(slider_rect.height()-2) + outline_rect.setHeight(height - 3) + if event is not None: + outline_rect.setWidth(event.size().width() - 1) + self.outlineBox.setRect(outline_rect) - self.slider.setRect(slider_rect) + handle_rect.setTop(self._handleTop()) + handle_rect.setHeight(self._handleHeight()) self.handle.setRect(handle_rect) + + select_box_rect.setHeight(self._handleHeight()) self.select_box.setRect(select_box_rect) self.updatePos() + self.drawHeader() super(VideoSlider, self).resizeEvent(event) + def _handleTop(self) -> float: + """Returns y position of top of handle (i.e., header height).""" + return 1 + self._header_height + + def _handleHeight(self, outline_rect=None) -> float: + """ + Returns visual height of handle. + + Args: + outline_rect: The rect of the outline box for the slider. This + is only required when calling during initialization (when the + outline box doesn't yet exist). + + Returns: + Height of handle in pixels. + """ + if outline_rect is None: + outline_rect = self.outlineBox.rect() + + handle_bottom_offset = 1 + handle_height = outline_rect.height() - ( + self._handleTop() + handle_bottom_offset + ) + return handle_height + def mousePressEvent(self, event): """Override method to move handle for mouse press/drag. @@ -460,20 +695,22 @@ def mousePressEvent(self, event): scenePos = self.mapToScene(event.pos()) # Do nothing if not enabled - if not self.enabled(): return + if not self.enabled(): + return # Do nothing if click outside slider area - if not self.slider.rect().contains(scenePos): return + if not self.outlineBox.rect().contains(scenePos): + return move_function = None release_function = None - if event.modifiers() == Qt.ShiftModifier: + if event.modifiers() == QtCore.Qt.ShiftModifier: move_function = self.moveSelectionAnchor release_function = self.releaseSelectionAnchor self.clearSelection() - elif event.modifiers() == Qt.NoModifier: + elif event.modifiers() == QtCore.Qt.NoModifier: move_function = self.moveHandle release_function = None @@ -513,25 +750,25 @@ def keyReleaseEvent(self, event): self.keyRelease.emit(event) event.accept() - def boundingRect(self) -> QRectF: + def boundingRect(self) -> QtCore.QRectF: """Method required by Qt.""" - return self.slider.rect() + return self.outlineBox.rect() def paint(self, *args, **kwargs): """Method required by Qt.""" super(VideoSlider, self).paint(*args, **kwargs) + if __name__ == "__main__": - app = QApplication([]) + app = QtWidgets.QApplication([]) window = VideoSlider( - min=0, max=20, val=15, - marks=(10,15)#((0,10),(0,15),(1,10),(1,11),(2,12)), tracks=3 - ) - window.setTracks(5) -# mark_positions = ((0,10),(0,15),(1,10),(1,11),(2,12),(3,12),(3,13),(3,14),(4,15),(4,16),(4,21)) - mark_positions = [("o",i) for i in range(3,15,4)] + [("f",18)] - window.setMarks(mark_positions) + min=0, + max=20, + val=15, + marks=(10, 15), # ((0,10),(0,15),(1,10),(1,11),(2,12)), tracks=3 + ) + window.valueChanged.connect(lambda x: print(x)) window.show() diff --git a/sleap/gui/suggestions.py b/sleap/gui/suggestions.py index 6287c3a9c..6189fae6a 100644 --- a/sleap/gui/suggestions.py +++ b/sleap/gui/suggestions.py @@ -1,3 +1,7 @@ +""" +Module for generating lists of suggested frames (for labeling or reviewing). +""" + import numpy as np import itertools @@ -7,37 +11,55 @@ import cv2 +from typing import List, Tuple + from sleap.io.video import Video + class VideoFrameSuggestions: + """ + Class for generating lists of suggested frames. + + Implements various algorithms as methods: + * strides + * random + * pca_cluster + * brisk + * proofreading + + Each of algorithm method should accept `video`; other parameters will be + passed from the `params` dict given to :meth:`suggest`. - rescale=True - rescale_below=512 + """ + + rescale = True + rescale_below = 512 @classmethod - def suggest(cls, video:Video, params:dict, labels: 'Labels'=None) -> list: + def suggest(cls, video: Video, params: dict, labels: "Labels" = None) -> List[int]: """ - This is the main entry point. + This is the main entry point for generating lists of suggested frames. Args: - video: a `Video` object for which we're generating suggestions - params: a dict with all params to control how we generate suggestions - * minimally this will have a `method` corresponding to a method in class - labels: a `Labels` object + video: A `Video` object for which we're generating suggestions. + params: A dictionary with all params to control how we generate + suggestions, minimally this will have a "method" key with + the name of one of the class methods. + labels: A `Labels` object for which we are generating suggestions. Returns: - list of frame suggestions + List of suggested frame indices. """ # map from method param value to corresponding class method method_functions = dict( - strides=cls.strides, - random=cls.random, - pca=cls.pca_cluster, - hog=cls.hog, - brisk=cls.brisk, - proofreading=cls.proofreading - ) + strides=cls.strides, + random=cls.random, + pca=cls.pca_cluster, + # hog=cls.hog, + brisk=cls.brisk, + proofreading=cls.proofreading, + ) method = params["method"] if method_functions.get(method, None) is not None: @@ -49,12 +71,14 @@ def suggest(cls, video:Video, params:dict, labels: 'Labels'=None) -> list: @classmethod def strides(cls, video, per_video=20, **kwargs): - suggestions = list(range(0, video.frames, video.frames//per_video)) + """Method to generate suggestions by taking strides through video.""" + suggestions = list(range(0, video.frames, video.frames // per_video)) suggestions = suggestions[:per_video] return suggestions @classmethod def random(cls, video, per_video=20, **kwargs): + """Method to generate suggestions by taking random frames in video.""" import random suggestions = random.sample(range(video.frames), per_video) @@ -62,51 +86,33 @@ def random(cls, video, per_video=20, **kwargs): @classmethod def pca_cluster(cls, video, initial_samples, **kwargs): - - sample_step = video.frames//initial_samples + """Method to generate suggestions by using PCA clusters.""" + sample_step = video.frames // initial_samples feature_stack, frame_idx_map = cls.frame_feature_stack(video, sample_step) result = cls.feature_stack_to_suggestions( - feature_stack, frame_idx_map, **kwargs) + feature_stack, frame_idx_map, **kwargs + ) return result @classmethod def brisk(cls, video, initial_samples, **kwargs): - - sample_step = video.frames//initial_samples + """Method to generate suggestions using PCA on Brisk features.""" + sample_step = video.frames // initial_samples feature_stack, frame_idx_map = cls.brisk_feature_stack(video, sample_step) result = cls.feature_stack_to_suggestions( - feature_stack, frame_idx_map, **kwargs) - - return result - - @classmethod - def hog( - cls, video, - clusters=5, - per_cluster=5, - sample_step=5, - pca_components=50, - interleave=True, - **kwargs): - - feature_stack, frame_idx_map = cls.hog_feature_stack(video, sample_step) - - result = cls.feature_stack_to_suggestions( - feature_stack, frame_idx_map, - clusters=clusters, per_cluster=per_cluster, - pca_components=pca_components, - interleave=interleave, - **kwargs) + feature_stack, frame_idx_map, **kwargs + ) return result @classmethod def proofreading( - cls, video: Video, labels: 'Labels', score_limit, instance_limit, **kwargs): - + cls, video: Video, labels: "Labels", score_limit, instance_limit, **kwargs + ): + """Method to generate suggestions for proofreading.""" score_limit = float(score_limit) instance_limit = int(instance_limit) @@ -124,14 +130,14 @@ def proofreading( if len(frame_scores) > instance_limit: frame_scores = sorted(frame_scores)[:instance_limit] # Add to matrix - scores[i,:len(frame_scores)] = frame_scores + scores[i, : len(frame_scores)] = frame_scores idxs[i] = lf.frame_idx # Find instances below score of low_instances = np.nansum(scores < score_limit, axis=1) # Find all the frames with at least low scoring instances - result = list(idxs[low_instances >= instance_limit]) + result = idxs[low_instances >= instance_limit].tolist() return result @@ -139,8 +145,11 @@ def proofreading( # These are specific to the suggestion method @classmethod - def frame_feature_stack(cls, video:Video, sample_step:int = 5) -> tuple: - sample_count = video.frames//sample_step + def frame_feature_stack( + cls, video: Video, sample_step: int = 5 + ) -> Tuple[np.ndarray, List[int]]: + """Generates matrix of sampled video frame images.""" + sample_count = video.frames // sample_step factor = cls.get_scale_factor(video) @@ -151,8 +160,8 @@ def frame_feature_stack(cls, video:Video, sample_step:int = 5) -> tuple: frame_idx = i * sample_step img = video[frame_idx].squeeze() - multichannel = (video.channels > 1) - img = rescale(img, scale=.5, anti_aliasing=True, multichannel=multichannel) + multichannel = video.channels > 1 + img = rescale(img, scale=0.5, anti_aliasing=True, multichannel=multichannel) flat_img = img.flatten() @@ -163,7 +172,10 @@ def frame_feature_stack(cls, video:Video, sample_step:int = 5) -> tuple: return (flat_stack, frame_idx_map) @classmethod - def brisk_feature_stack(cls, video:Video, sample_step:int = 5) -> tuple: + def brisk_feature_stack( + cls, video: Video, sample_step: int = 5 + ) -> Tuple[np.ndarray, List[int]]: + """Generates Brisk features from sampled video frames.""" brisk = cv2.BRISK_create() factor = cls.get_scale_factor(video) @@ -185,86 +197,80 @@ def brisk_feature_stack(cls, video:Video, sample_step:int = 5) -> tuple: return (feature_stack, frame_idx_map) - @classmethod - def hog_feature_stack(cls, video:Video, sample_step:int = 5) -> tuple: - sample_count = video.frames//sample_step - - hog = cv2.HOGDescriptor() - - factor = cls.get_scale_factor(video) - first_hog = hog.compute(cls.resize(video[0][0], factor)) - hog_size = first_hog.shape[0] - - frame_idx_map = [None] * sample_count - flat_stack = np.zeros((sample_count, hog_size)) - - for i in range(sample_count): - frame_idx = i * sample_step - img = video[frame_idx][0] - img = cls.resize(img, factor) - flat_stack[i] = hog.compute(img).transpose()[0] - frame_idx_map[i] = frame_idx - - return (flat_stack, frame_idx_map) - # Functions for making suggestions based on "feature stack" # These are common for all suggestion methods @staticmethod - def to_frame_idx_list(selected_list:list, frame_idx_map:dict) -> list: + def to_frame_idx_list( + selected_list: List[int], frame_idx_map: List[int] + ) -> List[int]: """Convert list of row indexes to list of frame indexes.""" return list(map(lambda x: frame_idx_map[x], selected_list)) @classmethod def feature_stack_to_suggestions( - cls, - feature_stack, frame_idx_map, - return_clusters=False, - **kwargs): + cls, + feature_stack: np.ndarray, + frame_idx_map: List[int], + return_clusters: bool = False, + **kwargs, + ) -> List[int]: """ Turns a feature stack matrix into a list of suggested frames. Args: - feature_stack: (n * features) matrix - frame_idx_map: n-length vector which gives frame_idx for each row in feature_stack - return_clusters (optional): return the intermediate result for debugging - i.e., a list that gives the list of suggested frames for each cluster + feature_stack: (n * features) matrix. + frame_idx_map: List indexed by rows of feature stack which gives + frame index for each row in feature_stack. This allows a + single frame to correspond to multiple rows in feature_stack. + return_clusters: Whether to return the intermediate result for + debugging, i.e., a list that gives the list of suggested + frames for each cluster. + + Returns: + List of frame index suggestions. """ selected_by_cluster = cls.feature_stack_to_clusters( - feature_stack=feature_stack, - frame_idx_map=frame_idx_map, - **kwargs) + feature_stack=feature_stack, frame_idx_map=frame_idx_map, **kwargs + ) - if return_clusters: return selected_by_cluster + if return_clusters: + return selected_by_cluster selected_list = cls.clusters_to_list( - selected_by_cluster=selected_by_cluster, - **kwargs) + selected_by_cluster=selected_by_cluster, **kwargs + ) return selected_list @classmethod def feature_stack_to_clusters( - cls, - feature_stack, - frame_idx_map, - clusters=5, - per_cluster=5, - pca_components=50, - **kwargs): + cls, + feature_stack: np.ndarray, + frame_idx_map: List[int], + clusters: int = 5, + per_cluster: int = 5, + pca_components: int = 50, + **kwargs, + ) -> List[int]: """ - Turns feature stack matrix into list (per cluster) of list of frame indexes. + Runs PCA to generate clusters of frames based on given features. Args: - feature_stack: (n * features) matrix + feature_stack: (n * features) matrix. + frame_idx_map: List indexed by rows of feature stack which gives + frame index for each row in feature_stack. This allows a + single frame to correspond to multiple rows in feature_stack. clusters: number of clusters - per_clusters: how many suggestions to take from each cluster (at most) - pca_components: for reducing feature space before clustering + per_cluster: How many suggestions (at most) to take from each + cluster. + pca_components: Number of PCA components, for reducing feature + space before clustering Returns: - list of lists - for each cluster, a list of frame indexes + A list of lists: + * for each cluster, a list of frame indices. """ stack_height = feature_stack.shape[0] @@ -284,7 +290,7 @@ def feature_stack_to_clusters( selected_by_cluster = [] selected_set = set() for i in range(clusters): - cluster_items, = np.where(row_labels==i) + cluster_items, = np.where(row_labels == i) # convert from row indexes to frame indexes cluster_items = cls.to_frame_idx_list(cluster_items, frame_idx_map) @@ -294,32 +300,37 @@ def feature_stack_to_clusters( cluster_items = list(set(cluster_items) - selected_set) # pick [per_cluster] items from this cluster - samples_from_bin = np.random.choice(cluster_items, min(len(cluster_items), per_cluster), False) + samples_from_bin = np.random.choice( + cluster_items, min(len(cluster_items), per_cluster), False + ) samples_from_bin.sort() selected_by_cluster.append(samples_from_bin) - selected_set = selected_set.union( set(samples_from_bin) ) + selected_set = selected_set.union(set(samples_from_bin)) return selected_by_cluster @classmethod - def clusters_to_list(cls, selected_by_cluster, interleave:bool = True, **kwargs) -> list: + def clusters_to_list( + cls, selected_by_cluster: List[List[int]], interleave: bool = True, **kwargs + ) -> list: """ - Turns list (per cluster) of lists of frame index into single list of frame indexes. + Merges per cluster suggestion lists into single list for entire video. Args: - selected_by_cluster: the list of lists of row indexes - frame_idx_map: map from row index to frame index - interleave: whether we want to interleave suggestions from clusters + selected_by_cluster: The list of lists of row indexes. + interleave: Whether to interleave suggestions from clusters. Returns: - list of frame index + List of frame indices. """ if interleave: # cycle clusters - all_selected = itertools.chain.from_iterable(itertools.zip_longest(*selected_by_cluster)) + all_selected = itertools.chain.from_iterable( + itertools.zip_longest(*selected_by_cluster) + ) # remove Nones and convert back to list - all_selected = list(filter(lambda x:x is not None, all_selected)) + all_selected = list(filter(lambda x: x is not None, all_selected)) else: all_selected = list(itertools.chain.from_iterable(selected_by_cluster)) all_selected.sort() @@ -331,34 +342,41 @@ def clusters_to_list(cls, selected_by_cluster, interleave:bool = True, **kwargs) # Utility functions @classmethod - def get_scale_factor(cls, video) -> int: + def get_scale_factor(cls, video: "Video") -> int: + """ + Determines how much we need to scale to get video within size. + + Size is specified by :attr:`rescale_below`. + """ factor = 1 if cls.rescale: largest_dim = max(video.height, video.width) factor = 1 - while largest_dim/factor > cls.rescale_below: + while largest_dim / factor > cls.rescale_below: factor += 1 return factor - @classmethod - def resize(cls, img, factor) -> np.ndarray: + @staticmethod + def resize(img: np.ndarray, factor: float) -> np.ndarray: + """Resizes frame image by scaling factor.""" h, w, _ = img.shape if factor != 1: - return cv2.resize(img, (h//factor, w//factor)) + return cv2.resize(img, (h // factor, w // factor)) else: return img + if __name__ == "__main__": # load some images filename = "tests/data/videos/centered_pair_small.mp4" filename = "files/190605_1509_frame23000_24000.sf.mp4" video = Video.from_filename(filename) - debug=False + debug = False - x = VideoFrameSuggestions.hog(video=video, sample_step=20, - clusters=5, per_cluster=5, - return_clusters=debug) + x = VideoFrameSuggestions.hog( + video=video, sample_step=20, clusters=5, per_cluster=5, return_clusters=debug + ) print(x) if debug: @@ -410,4 +428,4 @@ def resize(cls, img, factor) -> np.ndarray: # print(len(kp)) # print(des.shape) -# print(VideoFrameSuggestions.suggest(video, dict(method="pca"))) \ No newline at end of file +# print(VideoFrameSuggestions.suggest(video, dict(method="pca"))) diff --git a/sleap/gui/training_editor.py b/sleap/gui/training_editor.py index 9a1f4e6dc..77a09d644 100644 --- a/sleap/gui/training_editor.py +++ b/sleap/gui/training_editor.py @@ -1,4 +1,7 @@ -import os +""" +Module for viewing and modifying training profiles. +""" + import attr import cattr from typing import Optional @@ -7,21 +10,41 @@ from PySide2 import QtWidgets -from sleap.io.dataset import Labels from sleap.gui.formbuilder import YamlFormWidget -class TrainingEditor(QtWidgets.QDialog): - def __init__(self, profile_filename: Optional[str]=None, saved_files: list=[], *args, **kwargs): +class TrainingEditor(QtWidgets.QDialog): + """ + Dialog for viewing and modifying training profiles. + + Args: + profile_filename: Path to saved training profile to view. + saved_files: When user saved profile, it's path is added to this + list (which will be updated in code that created TrainingEditor). + """ + + def __init__( + self, + profile_filename: Optional[str] = None, + saved_files: list = [], + *args, + **kwargs + ): super(TrainingEditor, self).__init__() - form_yaml = resource_filename(Requirement.parse("sleap"),"sleap/config/training_editor.yaml") + form_yaml = resource_filename( + Requirement.parse("sleap"), "sleap/config/training_editor.yaml" + ) self.form_widgets = dict() - self.form_widgets["model"] = YamlFormWidget(form_yaml, "model", "Network Architecture") - self.form_widgets["datagen"] = YamlFormWidget(form_yaml, "datagen", "Data Generation/Preprocessing") + self.form_widgets["model"] = YamlFormWidget( + form_yaml, "model", "Network Architecture" + ) + self.form_widgets["datagen"] = YamlFormWidget( + form_yaml, "datagen", "Data Generation/Preprocessing" + ) self.form_widgets["trainer"] = YamlFormWidget(form_yaml, "trainer", "Trainer") - self.form_widgets["output"] = YamlFormWidget(form_yaml, "output",) + self.form_widgets["output"] = YamlFormWidget(form_yaml, "output") self.form_widgets["buttons"] = YamlFormWidget(form_yaml, "buttons") self.form_widgets["buttons"].mainAction.connect(self._save_as) @@ -47,10 +70,12 @@ def __init__(self, profile_filename: Optional[str]=None, saved_files: list=[], * @property def profile_filename(self): + """Returns path to currently loaded training profile.""" return self._profile_filename @profile_filename.setter def profile_filename(self, val): + """Sets path to (and loads) training profile.""" self._profile_filename = val # set window title self.setWindowTitle(self.profile_filename) @@ -64,7 +89,8 @@ def _layout_widget(layout): widget.setLayout(layout) return widget - def _load_profile(self, profile_filename:str): + def _load_profile(self, profile_filename: str): + """Loads training profile settings from file.""" from sleap.nn.model import ModelOutputType from sleap.nn.training import TrainingJob @@ -80,18 +106,13 @@ def _load_profile(self, profile_filename:str): for name in "datagen,trainer,output".split(","): self.form_widgets[name].set_form_data(job_dict["trainer"]) - def _update_profile(self): - # update training job from params in form - trainer = job.trainer - for key, val in form_data.items(): - # check if form field matches attribute of Trainer object - if key in dir(trainer): - setattr(trainer, key, val) - def _save_as(self): + """Shows dialog to save training profile.""" # Show "Save" dialog - save_filename, _ = QtWidgets.QFileDialog.getSaveFileName(self, caption="Save As...", dir=None, filter="Profile JSON (*.json)") + save_filename, _ = QtWidgets.QFileDialog.getSaveFileName( + self, caption="Save As...", dir=None, filter="Profile JSON (*.json)" + ) if len(save_filename): from sleap.nn.model import Model, ModelOutputType @@ -100,35 +121,47 @@ def _save_as(self): # Construct Model model_data = self.form_widgets["model"].get_form_data() - arch = dict(LeapCNN=leap.LeapCNN, - StackedHourglass=hourglass.StackedHourglass, - UNet=unet.UNet, - StackedUNet=unet.StackedUNet, - )[model_data["arch"]] - - output_type = dict(confmaps=ModelOutputType.CONFIDENCE_MAP, - pafs=ModelOutputType.PART_AFFINITY_FIELD, - centroids=ModelOutputType.CENTROIDS - )[model_data["output_type"]] - - backbone_kwargs = {key:val for key, val in model_data.items() - if key in attr.fields_dict(arch).keys()} + arch = dict( + LeapCNN=leap.LeapCNN, + StackedHourglass=hourglass.StackedHourglass, + UNet=unet.UNet, + StackedUNet=unet.StackedUNet, + )[model_data["arch"]] + + output_type = dict( + confmaps=ModelOutputType.CONFIDENCE_MAP, + pafs=ModelOutputType.PART_AFFINITY_FIELD, + centroids=ModelOutputType.CENTROIDS, + )[model_data["output_type"]] + + backbone_kwargs = { + key: val + for key, val in model_data.items() + if key in attr.fields_dict(arch).keys() + } model = Model(output_type=output_type, backbone=arch(**backbone_kwargs)) # Construct Trainer - trainer_data = {**self.form_widgets["datagen"].get_form_data(), - **self.form_widgets["output"].get_form_data(), - **self.form_widgets["trainer"].get_form_data(), - } - - trainer_kwargs = {key:val for key, val in trainer_data.items() - if key in attr.fields_dict(Trainer).keys()} + trainer_data = { + **self.form_widgets["datagen"].get_form_data(), + **self.form_widgets["output"].get_form_data(), + **self.form_widgets["trainer"].get_form_data(), + } + + trainer_kwargs = { + key: val + for key, val in trainer_data.items() + if key in attr.fields_dict(Trainer).keys() + } trainer = Trainer(**trainer_kwargs) # Construct TrainingJob - training_job_kwargs = {key:val for key, val in trainer_data.items() - if key in attr.fields_dict(TrainingJob).keys()} + training_job_kwargs = { + key: val + for key, val in trainer_data.items() + if key in attr.fields_dict(TrainingJob).keys() + } training_job = TrainingJob(model, trainer, **training_job_kwargs) # Write the file @@ -140,6 +173,7 @@ def _save_as(self): self.close() + if __name__ == "__main__": import sys @@ -150,4 +184,4 @@ def _save_as(self): app = QtWidgets.QApplication([]) win = TrainingEditor(profile_filename) win.show() - app.exec_() \ No newline at end of file + app.exec_() diff --git a/sleap/gui/video.py b/sleap/gui/video.py index 5bd3a472a..b1e79deb7 100644 --- a/sleap/gui/video.py +++ b/sleap/gui/video.py @@ -6,33 +6,40 @@ Example usage: >>> my_video = Video(...) >>> my_instance = Instance(...) - >>> color = (r, g, b) - >>> vp = QtVideoPlayer(video = my_video) - >>> vp.addInstance(instance = my_instance, color) + >>> vp = QtVideoPlayer(video=my_video) + >>> vp.addInstance(instance=my_instance, color=(r, g, b)) + """ from PySide2 import QtWidgets -from PySide2.QtWidgets import QApplication, QVBoxLayout, QWidget -from PySide2.QtWidgets import QLabel, QPushButton, QSlider -from PySide2.QtWidgets import QAction - -from PySide2.QtWidgets import QGraphicsView, QGraphicsScene +from PySide2.QtWidgets import ( + QApplication, + QVBoxLayout, + QWidget, + QGraphicsView, + QGraphicsScene, +) from PySide2.QtGui import QImage, QPixmap, QPainter, QPainterPath, QTransform from PySide2.QtGui import QPen, QBrush, QColor, QFont from PySide2.QtGui import QKeyEvent from PySide2.QtCore import Qt, Signal, Slot -from PySide2.QtCore import QRectF, QLineF, QPointF, QMarginsF, QSizeF +from PySide2.QtCore import QRectF, QPointF, QMarginsF import math -import numpy as np -from typing import Callable +from typing import Callable, List, Optional, Union from PySide2.QtWidgets import QGraphicsItem, QGraphicsObject + # The PySide2.QtWidgets.QGraphicsObject class provides a base class for all graphics items that require signals, slots and properties. -from PySide2.QtWidgets import QGraphicsEllipseItem, QGraphicsLineItem, QGraphicsTextItem, QGraphicsRectItem +from PySide2.QtWidgets import ( + QGraphicsEllipseItem, + QGraphicsLineItem, + QGraphicsTextItem, + QGraphicsRectItem, +) from sleap.skeleton import Skeleton from sleap.instance import Instance, Point @@ -46,15 +53,18 @@ class QtVideoPlayer(QWidget): """ Main QWidget for displaying video with skeleton instances. - Args: - video (optional): the :class:`Video` to display - Signals: - changedPlot: Emitted whenever the plot is redrawn - changedData: Emitted whenever data is changed by user + * changedPlot: Emitted whenever the plot is redrawn + * changedData: Emitted whenever data is changed by user + + Attributes: + video: The :class:`Video` to display + color_manager: A :class:`TrackColorManager` object which determines + which color to show the instances. + """ - changedPlot = Signal(QWidget, int, int) + changedPlot = Signal(QWidget, int, Instance) changedData = Signal(Instance) def __init__(self, video: Video = None, color_manager=None, *args, **kwargs): @@ -62,10 +72,10 @@ def __init__(self, video: Video = None, color_manager=None, *args, **kwargs): self._shift_key_down = False self.frame_idx = -1 - self._color_manager = color_manager + self.color_manager = color_manager self.view = GraphicsView() - self.seekbar = VideoSlider(color_manager=self._color_manager) + self.seekbar = VideoSlider(color_manager=self.color_manager) self.seekbar.valueChanged.connect(lambda evt: self.plot(self.seekbar.value())) self.seekbar.keyPress.connect(self.keyPressEvent) self.seekbar.keyRelease.connect(self.keyReleaseEvent) @@ -74,7 +84,7 @@ def __init__(self, video: Video = None, color_manager=None, *args, **kwargs): self.splitter = QtWidgets.QSplitter(Qt.Vertical) self.splitter.addWidget(self.view) self.splitter.addWidget(self.seekbar) - self.seekbar.updatedTracks.connect(lambda: self.splitter.refresh()) + self.seekbar.heightUpdated.connect(lambda: self.splitter.refresh()) self.layout = QVBoxLayout() self.layout.addWidget(self.splitter) @@ -104,11 +114,14 @@ def load_video(self, video: Video, initial_frame=0, plot=True): # self.seekbar.setTickInterval(1) self.seekbar.setValue(self.frame_idx) self.seekbar.setMinimum(0) - self.seekbar.setMaximum(self.video.frames - 1) + self.seekbar.setMaximum(self.video.last_frame_idx) self.seekbar.setEnabled(True) if plot: - self.plot(initial_frame) + try: + self.plot(initial_frame) + except: + pass def reset(self): """ Reset viewer by removing all video data. @@ -121,16 +134,24 @@ def reset(self): @property def instances(self): + """Returns list of all `QtInstance` objects in view.""" return self.view.instances @property def selectable_instances(self): + """Returns list of selectable `QtInstance` objects in view.""" return self.view.selectable_instances @property def predicted_instances(self): + """Returns list of predicted `QtInstance` objects in view.""" return self.view.predicted_instances + @property + def scene(self): + """Returns `QGraphicsScene` for viewer.""" + return self.view.scene + def addInstance(self, instance, **kwargs): """Add a skeleton instance to the video. @@ -142,23 +163,23 @@ def addInstance(self, instance, **kwargs): # Check if instance is an Instance (or subclass of Instance) if issubclass(type(instance), Instance): instance = QtInstance(instance=instance, **kwargs) - if type(instance) != QtInstance: return + if type(instance) != QtInstance: + return self.view.scene.addItem(instance) # connect signal from instance instance.changedData.connect(self.changedData) - # connect signal so we can adjust QtNodeLabel positions after zoom self.view.updatedViewer.connect(instance.updatePoints) - def plot(self, idx=None): + def plot(self, idx: Optional[int] = None): """ Do the actual plotting of the video frame. Args: - idx (optional): Go to frame idx. If None, stay on current frame. + idx: Go to frame idx. If None, stay on current frame. """ if self.video is None: @@ -175,28 +196,20 @@ def plot(self, idx=None): self.frame_idx = idx self.seekbar.setValue(self.frame_idx) - # Save index of selected instance - selected_idx = self.view.getSelection() - selected_idx = -1 if selected_idx is None else selected_idx # use -1 for no selection + # Store which Instance is selected + selected_inst = self.view.getSelectionInstance() # Clear existing objects self.view.clear() # Convert ndarray to QImage - # TODO: handle RGB and other formats - # https://stackoverflow.com/questions/34232632/convert-python-opencv-image-numpy-array-to-pyqt-qpixmap-image - # https://stackoverflow.com/questions/55063499/pyqt5-convert-cv2-image-to-qimage - # image = QImage(frame.copy().data, frame.shape[1], frame.shape[0], frame.shape[1], QImage.Format_Grayscale8) - # image = QImage(frame.copy().data, frame.shape[1], frame.shape[0], QImage.Format_Grayscale8) - - # Magic bullet: image = qimage2ndarray.array2qimage(frame) # Display image self.view.setImage(image) - # Emit signal (it's better to use the signal than a callback) - self.changedPlot.emit(self, idx, selected_idx) + # Emit signal + self.changedPlot.emit(self, idx, selected_inst) def nextFrame(self, dt=1): """ Go to next frame. @@ -245,20 +258,33 @@ def zoomToFit(self): if not zoom_rect.size().isEmpty(): self.view.zoomToRect(zoom_rect) - def onSequenceSelect(self, seq_len: int, on_success: Callable, - on_each = None, on_failure = None): + def onSequenceSelect( + self, + seq_len: int, + on_success: Callable, + on_each: Optional[Callable] = None, + on_failure: Optional[Callable] = None, + ): """ - Collect a sequence of instances (through user selection) and call `on_success`. - If the user cancels (by unselecting without new selection), call `on_failure`. + Collect a sequence of instances (through user selection). - Args: - seq_len: number of instances we expect user to select - on_success: callback after use has selected desired number of instances - on_failure (optional): callback if user cancels selection + When the sequence is complete, the `on_success` callback is called. + After each selection in sequence, the `on_each` callback is called + (if given). If the user cancels (by unselecting without new + selection), the `on_failure` callback is called (if given). Note: If successful, we call >>> on_success(sequence_of_selected_instance_indexes) + + Args: + seq_len: Number of instances we want to collect in sequence. + on_success: Callback for when user has selected desired number of + instances. + on_each: Callback after user selects each instance. + on_failure: Callback if user cancels process before selecting + enough instances. + """ indexes = [] @@ -266,11 +292,13 @@ def onSequenceSelect(self, seq_len: int, on_success: Callable, indexes.append(self.view.getSelection()) # Define function that will be called when user selects another instance - def handle_selection(seq_len=seq_len, - indexes=indexes, - on_success=on_success, - on_each=on_each, - on_failure=on_failure): + def handle_selection( + seq_len=seq_len, + indexes=indexes, + on_success=on_success, + on_each=on_each, + on_failure=on_failure, + ): # Get the index of the currently selected instance new_idx = self.view.getSelection() # If something is selected, add it to the list @@ -300,32 +328,67 @@ def handle_selection(seq_len=seq_len, on_each(indexes) @staticmethod - def _signal_once(signal, callback): + def _signal_once(signal: Signal, callback: Callable): + """ + Connects callback for next occurrence of signal. + + Args: + signal: The signal on which we want callback to be called. + callback: The function that should be called just once, the next + time the signal is emitted. + + Returns: + None. + """ + def call_once(*args): signal.disconnect(call_once) callback(*args) + signal.connect(call_once) def onPointSelection(self, callback: Callable): + """ + Starts mode for user to click point, callback called when finished. + + Args: + callback: The function called after user clicks point, should + take x and y as arguments. + + Returns: + None. + """ self.view.click_mode = "point" self.view.setCursor(Qt.CrossCursor) self._signal_once(self.view.pointSelected, callback) def onAreaSelection(self, callback: Callable): + """ + Starts mode for user to select area, callback called when finished. + + Args: + callback: The function called after user clicks point, should + take x0, y0, x1, y1 as arguments. + + Returns: + None. + """ self.view.click_mode = "area" self.view.setCursor(Qt.CrossCursor) self._signal_once(self.view.areaSelected, callback) def keyReleaseEvent(self, event: QKeyEvent): + """ + Custom event handler, tracks when user releases modifier (shift) key. + """ if event.key() == Qt.Key.Key_Shift: self._shift_key_down = False event.ignore() def keyPressEvent(self, event: QKeyEvent): - """ Custom event handler. - Move between frames, toggle display of edges/labels, and select instances. """ - ignore = False + Custom event handler, allows navigation and selection within view. + """ frame_t0 = self.frame_idx if event.key() == Qt.Key.Key_Shift: @@ -351,13 +414,11 @@ def keyPressEvent(self, event: QKeyEvent): self.view.nextSelection() elif event.key() < 128 and chr(event.key()).isnumeric(): # decrement by 1 since instances are 0-indexed - self.view.selectInstance(int(chr(event.key()))-1) + self.view.selectInstance(int(chr(event.key())) - 1) else: - event.ignore() # Kicks the event up to parent - # print(event.key()) + event.ignore() # Kicks the event up to parent # If user is holding down shift and action resulted in moving to another frame - # event.modifiers() == Qt.ShiftModifier and if self._shift_key_down and frame_t0 != self.frame_idx: # If there's no select, start seekbar selection at frame before action start, end = self.seekbar.getSelection() @@ -366,37 +427,44 @@ def keyPressEvent(self, event: QKeyEvent): # Set endpoint to frame after action self.seekbar.endSelection(self.frame_idx, update=True) + class GraphicsView(QGraphicsView): """ - QGraphicsView used by QtVideoPlayer. + Custom `QGraphicsView` used by `QtVideoPlayer`. - This contains elements for display of video and event handlers for zoom/selection. + This contains elements for display of video and event handlers for zoom + and selection of instances in view. Signals: - updatedViewer: Emitted after update to view (e.g., zoom) + * updatedViewer: Emitted after update to view (e.g., zoom). Used internally so we know when to update points for each instance. - updatedSelection: Emitted after the user has selected/unselected an instance - instanceDoubleClicked: Emitted after an instance is double clicked - - leftMouseButtonPressed - rightMouseButtonPressed - leftMouseButtonReleased - rightMouseButtonReleased - leftMouseButtonDoubleClicked - rightMouseButtonDoubleClicked + * updatedSelection: Emitted after the user has (un)selected an instance. + * instanceDoubleClicked: Emitted after an instance is double-clicked. + Passes the :class:`Instance` that was double-clicked. + * areaSelected: Emitted after user selects an area when in "area" + click mode. Passes x0, y0, x1, y1 for selected box coordinates. + * pointSelected: Emitted after user clicks a point (in "point" click + mode.) Passes x, y coordinates of point. + * leftMouseButtonPressed: Emitted by event handler. + * rightMouseButtonPressed: Emitted by event handler. + * leftMouseButtonReleased: Emitted by event handler. + * rightMouseButtonReleased: Emitted by event handler. + * leftMouseButtonDoubleClicked: Emitted by event handler. + * rightMouseButtonDoubleClicked: Emitted by event handler. + """ updatedViewer = Signal() updatedSelection = Signal() instanceDoubleClicked = Signal(Instance) + areaSelected = Signal(float, float, float, float) + pointSelected = Signal(float, float) leftMouseButtonPressed = Signal(float, float) rightMouseButtonPressed = Signal(float, float) leftMouseButtonReleased = Signal(float, float) rightMouseButtonReleased = Signal(float, float) leftMouseButtonDoubleClicked = Signal(float, float) rightMouseButtonDoubleClicked = Signal(float, float) - areaSelected = Signal(float, float, float, float) - pointSelected = Signal(float, float) def __init__(self, *args, **kwargs): """ https://github.com/marcel-goldschen-ohm/PyQtImageViewer/blob/master/QtImageViewer.py """ @@ -409,7 +477,6 @@ def __init__(self, *args, **kwargs): self._pixmapHandle = None self.setRenderHint(QPainter.Antialiasing) - # self.setCacheMode(QGraphicsView.CacheNone) self.aspectRatioMode = Qt.KeepAspectRatio self.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) @@ -421,12 +488,9 @@ def __init__(self, *args, **kwargs): self.zoomFactor = 1 anchor_mode = QGraphicsView.AnchorUnderMouse - # anchor_mode = QGraphicsView.AnchorViewCenter self.setTransformationAnchor(anchor_mode) - # self.scene.render() - - def hasImage(self): + def hasImage(self) -> bool: """ Returns whether or not the scene contains an image pixmap. """ return self._pixmapHandle is not None @@ -437,18 +501,27 @@ def clear(self): self._pixmapHandle = None self.scene.clear() - def setImage(self, image): - """ Set the scene's current image pixmap to the input QImage or QPixmap. - Raises a RuntimeError if the input image has type other than QImage or QPixmap. - :type image: QImage | QPixmap + def setImage(self, image: Union[QImage, QPixmap]): + """ + Set the scene's current image pixmap to the input QImage or QPixmap. + + Args: + image: The QPixmap or QImage to display. + + Raises: + RuntimeError: If the input image is not QImage or QPixmap + + Returns: + None. """ if type(image) is QPixmap: pixmap = image elif type(image) is QImage: - # pixmap = QPixmap.fromImage(image) pixmap = QPixmap(image) else: - raise RuntimeError("ImageViewer.setImage: Argument must be a QImage or QPixmap.") + raise RuntimeError( + "ImageViewer.setImage: Argument must be a QImage or QPixmap." + ) if self.hasImage(): self._pixmapHandle.setPixmap(pixmap) else: @@ -457,8 +530,7 @@ def setImage(self, image): self.updateViewer() def updateViewer(self): - """ Show current zoom (if showing entire image, apply current aspect ratio mode). - """ + """ Apply current zoom. """ if not self.hasImage(): return @@ -472,39 +544,41 @@ def updateViewer(self): self.updatedViewer.emit() @property - def instances(self): + def instances(self) -> List["QtInstance"]: """ Returns a list of instances. - Order in list should match the order in which instances were added to scene. + Order should match the order in which instances were added to scene. """ - return [item for item in self.scene.items(Qt.SortOrder.AscendingOrder) - if type(item) == QtInstance and not item.predicted] + return list(filter(lambda x: not x.predicted, self.all_instances)) @property - def selectable_instances(self): - return [item for item in self.scene.items(Qt.SortOrder.AscendingOrder) - if type(item) == QtInstance and item.selectable] + def predicted_instances(self) -> List["QtInstance"]: + """ + Returns a list of predicted instances. + + Order should match the order in which instances were added to scene. + """ + return list(filter(lambda x: not x.predicted, self.all_instances)) @property - def predicted_instances(self): + def selectable_instances(self) -> List["QtInstance"]: """ - Returns a list of predicted instances. + Returns a list of instances which user can select. - Order in list should match the order in which instances were added to scene. + Order should match the order in which instances were added to scene. """ - return [item for item in self.scene.items(Qt.SortOrder.AscendingOrder) - if type(item) == QtInstance and item.predicted] + return list(filter(lambda x: x.selectable, self.all_instances)) @property - def all_instances(self): + def all_instances(self) -> List["QtInstance"]: """ - Returns a list of instances and predicted instances. + Returns a list of all `QtInstance` objects in scene. - Order in list should match the order in which instances were added to scene. + Order should match the order in which instances were added to scene. """ - return [item for item in self.scene.items(Qt.SortOrder.AscendingOrder) - if type(item) == QtInstance] + scene_items = self.scene.items(Qt.SortOrder.AscendingOrder) + return list(filter(lambda x: isinstance(x, QtInstance), scene_items)) def clearSelection(self, signal=True): """ Clear instance skeleton selection. @@ -512,56 +586,65 @@ def clearSelection(self, signal=True): for instance in self.all_instances: instance.selected = False # signal that the selection has changed (so we can update visual display) - if signal: self.updatedSelection.emit() + if signal: + self.updatedSelection.emit() def nextSelection(self): """ Select next instance (or first, if none currently selected). """ instances = self.selectable_instances - if len(instances) == 0: return - select_inst = instances[0] # default to selecting first instance + if len(instances) == 0: + return + select_inst = instances[0] # default to selecting first instance select_idx = 0 for idx, instance in enumerate(instances): if instance.selected: instance.selected = False - select_idx = (idx+1)%len(instances) + select_idx = (idx + 1) % len(instances) select_inst = instances[select_idx] break select_inst.selected = True # signal that the selection has changed (so we can update visual display) self.updatedSelection.emit() - def selectInstance(self, select_idx, from_all=False, signal=True): + def selectInstance(self, select: Union[Instance, int], signal=True): """ - Select a particular skeleton instance. + Select a particular instance in view. Args: - select_idx: index of skeleton to select + select: Either `Instance` or index of instance in view. + signal: Whether to emit updatedSelection. + + Returns: + None """ - instances = self.selectable_instances if not from_all else self.all_instances self.clearSelection(signal=False) - if select_idx < len(instances): - for idx, instance in enumerate(instances): - instance.selected = (select_idx == idx) + + for idx, instance in enumerate(self.all_instances): + instance.selected = select == idx or select == instance.instance + # signal that the selection has changed (so we can update visual display) - if signal: self.updatedSelection.emit() + if signal: + self.updatedSelection.emit() - def getSelection(self): + def getSelection(self) -> int: """ Returns the index of the currently selected instance. If no instance selected, returns None. """ instances = self.all_instances - if len(instances) == 0: return None + if len(instances) == 0: + return None for idx, instance in enumerate(instances): if instance.selected: return idx - def getSelectionInstance(self): + def getSelectionInstance(self) -> Instance: """ Returns the currently selected instance. If no instance selected, returns None. """ instances = self.all_instances - if len(instances) == 0: return None + if len(instances) == 0: + return None for idx, instance in enumerate(instances): if instance.selected: return instance.instance @@ -603,7 +686,7 @@ def mouseReleaseEvent(self, event): QGraphicsView.mouseReleaseEvent(self, event) scenePos = self.mapToScene(event.pos()) # check if mouse moved during click - has_moved = (event.pos() != self._down_pos) + has_moved = event.pos() != self._down_pos if event.button() == Qt.LeftButton: if self.click_mode == "": @@ -611,12 +694,17 @@ def mouseReleaseEvent(self, event): if not has_moved: # When just a tap, see if there's an item underneath to select clicked = self.scene.items(scenePos, Qt.IntersectsItemBoundingRect) - clicked_instances = [item for item in clicked - if type(item) == QtInstance and item.selectable] + clicked_instances = [ + item + for item in clicked + if type(item) == QtInstance and item.selectable + ] # We only handle single instance selection so pick at most one from list - clicked_instance = clicked_instances[0] if len(clicked_instances) else None + clicked_instance = ( + clicked_instances[0] if len(clicked_instances) else None + ) for idx, instance in enumerate(self.selectable_instances): - instance.selected = (instance == clicked_instance) + instance.selected = instance == clicked_instance # If we want to allow selection of multiple instances, do this: # instance.selected = (instance in clicked) self.updatedSelection.emit() @@ -626,10 +714,11 @@ def mouseReleaseEvent(self, event): selection_rect = self.scene.selectionArea().boundingRect() self.areaSelected.emit( - selection_rect.left(), - selection_rect.top(), - selection_rect.right(), - selection_rect.bottom()) + selection_rect.left(), + selection_rect.top(), + selection_rect.right(), + selection_rect.bottom(), + ) elif self.click_mode == "point": selection_point = scenePos self.pointSelected.emit(scenePos.x(), scenePos.y()) @@ -644,7 +733,7 @@ def mouseReleaseEvent(self, event): elif event.button() == Qt.RightButton: if self.canZoom: zoom_rect = self.scene.selectionArea().boundingRect() - self.scene.setSelectionArea(QPainterPath()) # clear selection + self.scene.setSelectionArea(QPainterPath()) # clear selection self.zoomToRect(zoom_rect) self.setDragMode(QGraphicsView.NoDrag) self.rightMouseButtonReleased.emit(scenePos.x(), scenePos.y()) @@ -659,13 +748,13 @@ def zoomToRect(self, zoom_rect: QRectF): Args: zoom_rect: The `QRectF` to which we want to zoom. - relative: Controls whether rect is relative to current zoom. """ - if zoom_rect.isNull(): return + if zoom_rect.isNull(): + return - scale_h = self.scene.height()/zoom_rect.height() - scale_w = self.scene.width()/zoom_rect.width() + scale_h = self.scene.height() / zoom_rect.height() + scale_w = self.scene.width() / zoom_rect.width() scale = min(scale_h, scale_w) self.zoomFactor = scale @@ -677,7 +766,7 @@ def clearZoom(self): """ self.zoomFactor = 1 - def instancesBoundingRect(self, margin=0): + def instancesBoundingRect(self, margin: float = 0) -> QRectF: """ Returns a rect which contains all displayed skeleton instances. @@ -694,7 +783,7 @@ def instancesBoundingRect(self, margin=0): return rect def mouseDoubleClickEvent(self, event): - """ Custom event handler. Show entire image. + """ Custom event handler, clears zoom. """ scenePos = self.mapToScene(event.pos()) if event.button() == Qt.LeftButton: @@ -731,10 +820,12 @@ def wheelEvent(self, event): pass def keyPressEvent(self, event): - event.ignore() # Kicks the event up to parent + """Custom event hander, disables default QGraphicsView behavior.""" + event.ignore() # Kicks the event up to parent def keyReleaseEvent(self, event): - event.ignore() # Kicks the event up to parent + """Custom event hander, disables default QGraphicsView behavior.""" + event.ignore() # Kicks the event up to parent class QtNodeLabel(QGraphicsTextItem): @@ -782,20 +873,24 @@ def adjustPos(self, *args, **kwargs): if len(node.edges): edge_angles = sorted([edge.angle_to(node) for edge in node.edges]) - edge_angles.append(edge_angles[0] + math.pi*2) + edge_angles.append(edge_angles[0] + math.pi * 2) # Calculate size and bisector for each arc between adjacent edges - edge_arcs = [(edge_angles[i+1]-edge_angles[i], - edge_angles[i+1]/2+edge_angles[i]/2) - for i in range(len(edge_angles)-1)] + edge_arcs = [ + ( + edge_angles[i + 1] - edge_angles[i], + edge_angles[i + 1] / 2 + edge_angles[i] / 2, + ) + for i in range(len(edge_angles) - 1) + ] max_arc = sorted(edge_arcs)[-1] - shift_angle = max_arc[1] # this is the angle of the bisector - shift_angle %= 2*math.pi + shift_angle = max_arc[1] # this is the angle of the bisector + shift_angle %= 2 * math.pi # Use the _shift_factor to control how the label is positioned # relative to the node. # Shift factor of -1 means we shift label up/left by its height/width. - self._shift_factor_x = (math.cos(shift_angle)*.6) -.5 - self._shift_factor_y = (math.sin(shift_angle)*.6) -.5 + self._shift_factor_x = (math.cos(shift_angle) * 0.6) - 0.5 + self._shift_factor_y = (math.sin(shift_angle) * 0.6) - 0.5 # Since item doesn't scale when view is transformed (i.e., zoom) # we need to calculate bounding size in view manually. @@ -809,8 +904,10 @@ def adjustPos(self, *args, **kwargs): height = height / view.viewportTransform().m11() width = width / view.viewportTransform().m22() - self.setPos(self._anchor_x + width*self._shift_factor_x, - self._anchor_y + height*self._shift_factor_y) + self.setPos( + self._anchor_x + width * self._shift_factor_x, + self._anchor_y + height * self._shift_factor_y, + ) # Now apply these changes to the visual display self.adjustStyle() @@ -819,7 +916,9 @@ def adjustStyle(self): """ Update visual display of the label and its node. """ - complete_color = QColor(80, 194, 159) if self.node.point.complete else QColor(232, 45, 32) + complete_color = ( + QColor(80, 194, 159) if self.node.point.complete else QColor(232, 45, 32) + ) if self.predicted: self._base_font.setBold(False) @@ -833,13 +932,13 @@ def adjustStyle(self): elif self.node.point.complete: self._base_font.setBold(True) self.setFont(self._base_font) - self.setDefaultTextColor(complete_color) # greenish + self.setDefaultTextColor(complete_color) # greenish # FIXME: Adjust style of node here as well? # self.node.setBrush(complete_color) else: self._base_font.setBold(False) self.setFont(self._base_font) - self.setDefaultTextColor(complete_color) # redish + self.setDefaultTextColor(complete_color) # redish def boundingRect(self): """ Method required by Qt. @@ -887,9 +986,21 @@ class QtNode(QGraphicsEllipseItem): color: Color of the visual node item. callbacks: List of functions to call after we update to the `Point`. """ - def __init__(self, parent, point:Point, radius:float, color:list, node_name:str = None, - predicted=False, color_predicted=False, show_non_visible=True, - callbacks = None, *args, **kwargs): + + def __init__( + self, + parent, + point: Point, + radius: float, + color: list, + node_name: str = None, + predicted=False, + color_predicted=False, + show_non_visible=True, + callbacks=None, + *args, + **kwargs, + ): self._parent = parent self.point = point self.radius = radius @@ -902,7 +1013,15 @@ def __init__(self, parent, point:Point, radius:float, color:list, node_name:str self.callbacks = [] if callbacks is None else callbacks self.dragParent = False - super(QtNode, self).__init__(-self.radius, -self.radius, self.radius*2, self.radius*2, parent=parent, *args, **kwargs) + super(QtNode, self).__init__( + -self.radius, + -self.radius, + self.radius * 2, + self.radius * 2, + parent=parent, + *args, + **kwargs, + ) if node_name is not None: self.setToolTip(node_name) @@ -926,7 +1045,9 @@ def __init__(self, parent, point:Point, radius:float, color:list, node_name:str self.setFlag(QGraphicsItem.ItemIsMovable) self.pen_default = QPen(col_line, 1) - self.pen_default.setCosmetic(True) # https://stackoverflow.com/questions/13120486/adjusting-qpen-thickness-when-scaling-qgraphicsview + self.pen_default.setCosmetic( + True + ) # https://stackoverflow.com/questions/13120486/adjusting-qpen-thickness-when-scaling-qgraphicsview self.pen_missing = QPen(col_line, 1) self.pen_missing.setCosmetic(True) self.brush = QBrush(QColor(*self.color, a=128)) @@ -942,11 +1063,12 @@ def calls(self): if callable(callback): callback(self) - def updatePoint(self, user_change=True): - """ Method to update data for node/edge after user manipulates visual point. + def updatePoint(self, user_change: bool = True): + """ + Method to update data for node/edge when node position is manipulated. Args: - user_change (optional): Is this being called because of change by user? + user_change: Whether this being called because of change by user. """ self.point.x = self.scenePos().x() self.point.y = self.scenePos().y() @@ -957,12 +1079,12 @@ def updatePoint(self, user_change=True): self.setPen(self.pen_default) self.setBrush(self.brush) else: - radius = self.radius / 2. + radius = self.radius / 2.0 self.setPen(self.pen_missing) self.setBrush(self.brush_missing) if not self.show_non_visible: self.hide() - self.setRect(-radius, -radius, radius*2, radius*2) + self.setRect(-radius, -radius, radius * 2, radius * 2) for edge in self.edges: edge.updateEdge(self) @@ -973,13 +1095,15 @@ def updatePoint(self, user_change=True): self.calls() # Emit event if we're updating from a user change - if user_change: self._parent.changedData.emit(self._parent.instance) + if user_change: + self._parent.changedData.emit(self._parent.instance) def mousePressEvent(self, event): """ Custom event handler for mouse press. """ # Do nothing if node is from predicted instance - if self.parentObject().predicted: return + if self.parentObject().predicted: + return self.setCursor(Qt.ArrowCursor) @@ -1011,17 +1135,19 @@ def mousePressEvent(self, event): def mouseMoveEvent(self, event): """ Custom event handler for mouse move. """ - #print(event) + # print(event) if self.dragParent: self.parentObject().mouseMoveEvent(event) else: super(QtNode, self).mouseMoveEvent(event) - self.updatePoint(user_change=False) # don't count change until mouse release + self.updatePoint( + user_change=False + ) # don't count change until mouse release def mouseReleaseEvent(self, event): """ Custom event handler for mouse release. """ - #print(event) + # print(event) self.unsetCursor() if self.dragParent: self.parentObject().mouseReleaseEvent(event) @@ -1041,11 +1167,13 @@ def wheelEvent(self, event): event.accept() def mouseDoubleClickEvent(self, event): + """Custom event handler to emit signal on event.""" scene = self.scene() if scene is not None: view = scene.views()[0] view.instanceDoubleClicked.emit(self.parentObject().instance) + class QtEdge(QGraphicsLineItem): """ QGraphicsLineItem to handle display of edge between skeleton instance nodes. @@ -1053,27 +1181,46 @@ class QtEdge(QGraphicsLineItem): Args: src: The `QtNode` source node for the edge. dst: The `QtNode` destination node for the edge. + color: Color as (r, g, b) tuple. + show_non_visible: Whether to show "non-visible" nodes/edges. """ - def __init__(self, parent, src:QtNode, dst:QtNode, color, - show_non_visible=True, - *args, **kwargs): + + def __init__( + self, + parent, + src: QtNode, + dst: QtNode, + color, + show_non_visible=True, + *args, + **kwargs, + ): self.src = src self.dst = dst self.show_non_visible = show_non_visible - super(QtEdge, self).__init__(self.src.point.x, self.src.point.y, self.dst.point.x, self.dst.point.y, parent=parent, *args, **kwargs) + super(QtEdge, self).__init__( + self.src.point.x, + self.src.point.y, + self.dst.point.x, + self.dst.point.y, + parent=parent, + *args, + **kwargs, + ) pen = QPen(QColor(*color), 1) pen.setCosmetic(True) self.setPen(pen) self.full_opacity = 1 - def connected_to(self, node): + def connected_to(self, node: QtNode): """ Return the other node along the edge. Args: node: One of the edge's nodes. + Returns: The other node (or None if edge doesn't have node). """ @@ -1084,7 +1231,7 @@ def connected_to(self, node): return None - def angle_to(self, node): + def angle_to(self, node: QtNode) -> float: """ Returns the angle from one edge node to the other. @@ -1099,17 +1246,20 @@ def angle_to(self, node): y = to.point.y - node.point.y return math.atan2(y, x) - def updateEdge(self, node): + def updateEdge(self, node: QtNode): """ Updates the visual display of node. Args: node: The node to update. + + Returns: + None. """ if self.src.point.visible and self.dst.point.visible: self.full_opacity = 1 else: - self.full_opacity = .5 if self.show_non_visible else 0 + self.full_opacity = 0.5 if self.show_non_visible else 0 self.setOpacity(self.full_opacity) if node == self.src: @@ -1131,22 +1281,37 @@ class QtInstance(QGraphicsObject): and handles the events to manipulate the skeleton within a video frame (i.e., moving, rotating, marking nodes). - It should be instatiated with a `Skeleton` or `Instance` - and added to the relevant `QGraphicsScene`. + It should be instantiated with an `Instance` and added to the relevant + `QGraphicsScene`. When instantiated, it creates `QtNode`, `QtEdge`, and `QtNodeLabel` items as children of itself. + + Args: + instance: The :class:`Instance` to show. + predicted: Whether this is a predicted instance. + color_predicted: Whether to show predicted instance in color. + color: Color of the visual item. + markerRadius: Radius of nodes. + show_non_visible: Whether to show "non-visible" nodes/edges. + """ changedData = Signal(Instance) - def __init__(self, skeleton:Skeleton = None, instance: Instance = None, - predicted=False, color_predicted=False, - color=(0, 114, 189), markerRadius=4, - show_non_visible=True, - *args, **kwargs): + def __init__( + self, + instance: Instance = None, + predicted=False, + color_predicted=False, + color=(0, 114, 189), + markerRadius=4, + show_non_visible=True, + *args, + **kwargs, + ): super(QtInstance, self).__init__(*args, **kwargs) - self.skeleton = skeleton if instance is None else instance.skeleton + self.skeleton = instance.skeleton self.instance = instance self.predicted = predicted self.color_predicted = color_predicted @@ -1162,8 +1327,6 @@ def __init__(self, skeleton:Skeleton = None, instance: Instance = None, self.labels_shown = True self._selected = False self._bounding_rect = QRectF() - #self.setFlag(QGraphicsItem.ItemIsMovable) - #self.setFlag(QGraphicsItem.ItemIsSelectable) if self.predicted: self.setZValue(0) @@ -1181,7 +1344,6 @@ def __init__(self, skeleton:Skeleton = None, instance: Instance = None, self.track_label = QtTextWithBackground(parent=self) self.track_label.setDefaultTextColor(QColor(*self.color)) - self.track_label.setFlag(QGraphicsItem.ItemIgnoresTransformations) instance_label_text = "" if self.instance.track is not None: @@ -1190,15 +1352,23 @@ def __init__(self, skeleton:Skeleton = None, instance: Instance = None, track_name = "[none]" instance_label_text += f"Track: {track_name}" if hasattr(self.instance, "score"): - instance_label_text += f"
Prediction Score: {round(self.instance.score, 2)}" + instance_label_text += ( + f"
Prediction Score: {round(self.instance.score, 2)}" + ) self.track_label.setHtml(instance_label_text) # Add nodes for (node, point) in self.instance.nodes_points: - node_item = QtNode(parent=self, point=point, node_name=node.name, - predicted=self.predicted, color_predicted=self.color_predicted, - color=self.color, radius=self.markerRadius, - show_non_visible=self.show_non_visible) + node_item = QtNode( + parent=self, + point=point, + node_name=node.name, + predicted=self.predicted, + color_predicted=self.color_predicted, + color=self.color, + radius=self.markerRadius, + show_non_visible=self.show_non_visible, + ) self.nodes[node.name] = node_item @@ -1206,8 +1376,13 @@ def __init__(self, skeleton:Skeleton = None, instance: Instance = None, for (src, dst) in self.skeleton.edge_names: # Make sure that both nodes are present in this instance before drawing edge if src in self.nodes and dst in self.nodes: - edge_item = QtEdge(parent=self, src=self.nodes[src], dst=self.nodes[dst], - color=self.color, show_non_visible=self.show_non_visible) + edge_item = QtEdge( + parent=self, + src=self.nodes[src], + dst=self.nodes[dst], + color=self.color, + show_non_visible=self.show_non_visible, + ) self.nodes[src].edges.append(edge_item) self.nodes[dst].edges.append(edge_item) self.edges.append(edge_item) @@ -1227,16 +1402,18 @@ def __init__(self, skeleton:Skeleton = None, instance: Instance = None, # Update size of box so it includes all the nodes/edges self.updateBox() - def updatePoints(self, complete:bool = False, user_change:bool = False): + def updatePoints(self, complete: bool = False, user_change: bool = False): """ Updates data and display for all points in skeleton. This is called any time the skeleton is manipulated as a whole. Args: - complete (optional): If set, we mark the state of all - nodes in the skeleton to "complete". - user_change (optional): Is this being called because of change by user? + complete: Whether to update all nodes by setting "completed" + attribute. + user_change: Whether method is called because of change made by + user. + Returns: None. """ @@ -1246,7 +1423,8 @@ def updatePoints(self, complete:bool = False, user_change:bool = False): node_item.point.x = node_item.scenePos().x() node_item.point.y = node_item.scenePos().y() node_item.setPos(node_item.point.x, node_item.point.y) - if complete: node_item.point.complete = True + if complete: + node_item.point.complete = True # Wait to run callbacks until all nodes are updated # Otherwise the label positions aren't correct since # they depend on the edge vectors to old node positions. @@ -1262,13 +1440,18 @@ def updatePoints(self, complete:bool = False, user_change:bool = False): # Update box for instance selection self.updateBox() # Emit event if we're updating from a user change - if user_change: self.changedData.emit(self.instance) + if user_change: + self.changedData.emit(self.instance) - def getPointsBoundingRect(self): + def getPointsBoundingRect(self) -> QRectF: """Returns a rect which contains all the nodes in the skeleton.""" rect = None for item in self.edges: - rect = item.boundingRect() if rect is None else rect.united(item.boundingRect()) + rect = ( + item.boundingRect() + if rect is None + else rect.united(item.boundingRect()) + ) return rect def updateBox(self, *args, **kwargs): @@ -1280,7 +1463,7 @@ def updateBox(self, *args, **kwargs): select this instance. """ # Only show box if instance is selected - op = .7 if self._selected else 0 + op = 0.7 if self._selected else 0 self.box.setOpacity(op) # Update the position for the box rect = self.getPointsBoundingRect() @@ -1293,10 +1476,12 @@ def updateBox(self, *args, **kwargs): @property def selected(self): + """Whether instance is selected.""" return self._selected @selected.setter - def selected(self, selected:bool): + def selected(self, selected: bool): + """Sets select-state for instance.""" self._selected = selected # Update the selection box for this skeleton instance self.updateBox() @@ -1306,7 +1491,7 @@ def toggleLabels(self): """ self.showLabels(not self.labels_shown) - def showLabels(self, show): + def showLabels(self, show: bool): """ Draws/hides the labels for this skeleton instance. @@ -1323,7 +1508,7 @@ def toggleEdges(self): """ self.showEdges(not self.edges_shown) - def showEdges(self, show = True): + def showEdges(self, show=True): """ Draws/hides the edges for this skeleton instance. @@ -1345,10 +1530,17 @@ def paint(self, painter, option, widget=None): """ pass + class QtTextWithBackground(QGraphicsTextItem): + """ + Inherits methods/behavior of `QGraphicsTextItem`, but with background box. + + Color of brackground box is light or dark depending on the text color. + """ def __init__(self, *args, **kwargs): super(QtTextWithBackground, self).__init__(*args, **kwargs) + self.setFlag(QGraphicsItem.ItemIgnoresTransformations) def boundingRect(self): """ Method required by Qt. @@ -1360,53 +1552,67 @@ def paint(self, painter, option, *args, **kwargs): """ text_color = self.defaultTextColor() brush = painter.brush() - background_color = "white" if text_color.lightnessF() < .4 else "black" - background_color = QColor(background_color, a=.5) + background_color = "white" if text_color.lightnessF() < 0.4 else "black" + background_color = QColor(background_color, a=0.5) painter.setBrush(QBrush(background_color)) painter.drawRect(self.boundingRect()) painter.setBrush(brush) super(QtTextWithBackground, self).paint(painter, option, *args, **kwargs) + def video_demo(labels, standalone=False): + """Demo function for showing (first) video from dataset.""" video = labels.videos[0] - if standalone: app = QApplication([]) + if standalone: + app = QApplication([]) window = QtVideoPlayer(video=video) - window.changedPlot.connect(lambda vp, idx, select_idx: plot_instances(vp.view.scene, idx, labels, video)) + window.changedPlot.connect( + lambda vp, idx, select_idx: plot_instances(vp.view.scene, idx, labels, video) + ) window.show() window.plot() - if standalone: app.exec_() + if standalone: + app.exec_() + def plot_instances(scene, frame_idx, labels, video=None, fixed=True): + """Demo function for plotting instances.""" + from sleap.gui.overlays.tracks import TrackColorManager + video = labels.videos[0] - color_manager = TrackColorManager(labels) - lfs = [label for label in labels.labels if label.video == video and label.frame_idx == frame_idx] + color_manager = TrackColorManager(labels=labels) + lfs = labels.find(video, frame_idx) - if len(lfs) == 0: return + if not lfs: + return labeled_frame = lfs[0] count_no_track = 0 for i, instance in enumerate(labeled_frame.instances_to_show): - if instance.track in self.labels.tracks: + if instance.track in labels.tracks: pseudo_track = instance.track else: # Instance without track - pseudo_track = len(self.labels.tracks) + count_no_track + pseudo_track = len(labels.tracks) + count_no_track count_no_track += 1 # Plot instance - inst = QtInstance(instance=instance, - color=color_manager(pseudo_track), - predicted=fixed, - color_predicted=True, - show_non_visible=False) + inst = QtInstance( + instance=instance, + color=color_manager.get_color(pseudo_track), + predicted=fixed, + color_predicted=True, + show_non_visible=False, + ) inst.showLabels(False) scene.addItem(inst) inst.updatePoints() + if __name__ == "__main__": import argparse @@ -1417,4 +1623,4 @@ def plot_instances(scene, frame_idx, labels, video=None, fixed=True): args = parser.parse_args() labels = Labels.load_json(args.data_path) - video_demo(labels, standalone=True) \ No newline at end of file + video_demo(labels, standalone=True) diff --git a/sleap/info/__init__.py b/sleap/info/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/sleap/info/labels.py b/sleap/info/labels.py index 775725a75..e99d2b900 100644 --- a/sleap/info/labels.py +++ b/sleap/info/labels.py @@ -1,3 +1,6 @@ +""" +Command line utility which prints data about labels file. +""" import os from sleap.io.dataset import Labels @@ -17,6 +20,8 @@ print(f"Video files:") + total_user_frames = 0 + for vid in labels.videos: lfs = labels.find(vid) @@ -25,9 +30,15 @@ tracks = {inst.track for lf in lfs for inst in lf} concurrent_count = max((len(lf.instances) for lf in lfs)) + user_frames = len(labels.get_video_user_labeled_frames(vid)) + + total_user_frames += user_frames print(f" {vid.filename}") - print(f" labeled from {first_idx} to {last_idx}") + print(f" labeled frames from {first_idx} to {last_idx}") + print(f" labeled frames: {len(lfs)}") + print(f" user labeled frames: {user_frames}") print(f" tracks: {len(tracks)}") print(f" max instances in frame: {concurrent_count}") + print(f"Total user labeled frames: {total_user_frames}") diff --git a/sleap/info/metrics.py b/sleap/info/metrics.py index b3acddb8c..3412d2626 100644 --- a/sleap/info/metrics.py +++ b/sleap/info/metrics.py @@ -1,5 +1,7 @@ +""" +Module for producing prediction metrics for SLEAP datasets. +""" from inspect import signature -import itertools import numpy as np from scipy.optimize import linear_sum_assignment from typing import Callable, List, Optional, Union, Tuple @@ -7,12 +9,13 @@ from sleap.instance import Instance, PredictedInstance from sleap.io.dataset import Labels + def matched_instance_distances( - labels_gt: Labels, - labels_pr: Labels, - match_lists_function: Callable, - frame_range: Optional[range]=None) -> Tuple[ - List[int], np.ndarray, np.ndarray, np.ndarray]: + labels_gt: Labels, + labels_pr: Labels, + match_lists_function: Callable, + frame_range: Optional[range] = None, +) -> Tuple[List[int], np.ndarray, np.ndarray, np.ndarray]: """ Distances between ground truth and predicted nodes over a set of frames. @@ -62,7 +65,7 @@ def matched_instance_distances( points_gt.append(sorted_gt) points_pr.append(sorted_pr) - frame_idxs.extend([frame_idx]*len(sorted_gt)) + frame_idxs.extend([frame_idx] * len(sorted_gt)) # Convert arrays to numpy matrixes # instances * nodes * (x,y) @@ -75,26 +78,32 @@ def matched_instance_distances( return frame_idxs, D, points_gt, points_pr + def match_instance_lists( - instances_a: List[Union[Instance, PredictedInstance]], - instances_b: List[Union[Instance, PredictedInstance]], - cost_function: Callable) -> Tuple[ - List[Union[Instance, PredictedInstance]], - List[Union[Instance, PredictedInstance]]]: + instances_a: List[Union[Instance, PredictedInstance]], + instances_b: List[Union[Instance, PredictedInstance]], + cost_function: Callable, +) -> Tuple[ + List[Union[Instance, PredictedInstance]], List[Union[Instance, PredictedInstance]] +]: """Sorts two lists of Instances to find best overall correspondence for a given cost function (e.g., total distance between points).""" - pairwise_distance_matrix = calculate_pairwise_cost(instances_a, instances_b, cost_function) + pairwise_distance_matrix = calculate_pairwise_cost( + instances_a, instances_b, cost_function + ) match_a, match_b = linear_sum_assignment(pairwise_distance_matrix) sorted_a = list(map(lambda idx: instances_a[idx], match_a)) sorted_b = list(map(lambda idx: instances_b[idx], match_b)) return sorted_a, sorted_b + def calculate_pairwise_cost( - instances_a: List[Union[Instance, PredictedInstance]], - instances_b: List[Union[Instance, PredictedInstance]], - cost_function: Callable) -> np.ndarray: + instances_a: List[Union[Instance, PredictedInstance]], + instances_b: List[Union[Instance, PredictedInstance]], + cost_function: Callable, +) -> np.ndarray: """Calculate (a * b) matrix of pairwise costs using cost function.""" matrix_size = (len(instances_a), len(instances_b)) @@ -114,12 +123,14 @@ def calculate_pairwise_cost( pairwise_cost_matrix[idx_a, idx_b] = cost return pairwise_cost_matrix + def match_instance_lists_nodewise( - instances_a: List[Union[Instance, PredictedInstance]], - instances_b: List[Union[Instance, PredictedInstance]], - thresh: float=5) -> Tuple[ - List[Union[Instance, PredictedInstance]], - List[Union[Instance, PredictedInstance]]]: + instances_a: List[Union[Instance, PredictedInstance]], + instances_b: List[Union[Instance, PredictedInstance]], + thresh: float = 5, +) -> Tuple[ + List[Union[Instance, PredictedInstance]], List[Union[Instance, PredictedInstance]] +]: """For each node for each instance in the first list, pairs it with the closest corresponding node from *any* instance in the second list.""" @@ -141,8 +152,8 @@ def match_instance_lists_nodewise( for node_idx in range(node_count): # Make sure there's some prediction for this node - if any(~np.isnan(dist_array[:,node_idx])): - best_idx = np.nanargmin(dist_array[:,node_idx]) + if any(~np.isnan(dist_array[:, node_idx])): + best_idx = np.nanargmin(dist_array[:, node_idx]) # Ignore closest point if distance is beyond threshold if dist_array[best_idx, node_idx] <= thresh: @@ -153,26 +164,31 @@ def match_instance_lists_nodewise( return instances_a, best_points_array + def point_dist( - inst_a: Union[Instance, PredictedInstance], - inst_b: Union[Instance, PredictedInstance]) -> np.ndarray: + inst_a: Union[Instance, PredictedInstance], + inst_b: Union[Instance, PredictedInstance], +) -> np.ndarray: """Given two instances, returns array of distances for corresponding nodes.""" - points_a = inst_a.points_array(invisible_as_nan=True) - points_b = inst_b.points_array(invisible_as_nan=True) + points_a = inst_a.points_array + points_b = inst_b.points_array point_dist = np.linalg.norm(points_a - points_b, axis=1) return point_dist -def nodeless_point_dist(inst_a: Union[Instance, PredictedInstance], - inst_b: Union[Instance, PredictedInstance]) -> np.ndarray: + +def nodeless_point_dist( + inst_a: Union[Instance, PredictedInstance], + inst_b: Union[Instance, PredictedInstance], +) -> np.ndarray: """Given two instances, returns array of distances for closest points ignoring node identities.""" matrix_size = (len(inst_a.skeleton.nodes), len(inst_b.skeleton.nodes)) pairwise_distance_matrix = np.full(matrix_size, 0) - points_a = inst_a.points_array(invisible_as_nan=True) - points_b = inst_b.points_array(invisible_as_nan=True) + points_a = inst_a.points_array + points_b = inst_b.points_array # Calculate the distance between any pair of inst A and inst B points for idx_a in range(points_a.shape[0]): @@ -185,15 +201,17 @@ def nodeless_point_dist(inst_a: Union[Instance, PredictedInstance], match_a, match_b = linear_sum_assignment(pairwise_distance_matrix) # Sort points by this match and calculate overall distance - sorted_points_a = points_a[match_a,:] - sorted_points_b = points_b[match_b,:] + sorted_points_a = points_a[match_a, :] + sorted_points_b = points_b[match_b, :] point_dist = np.linalg.norm(points_a - points_b, axis=1) return point_dist + def compare_instance_lists( - instances_a: List[Union[Instance, PredictedInstance]], - instances_b: List[Union[Instance, PredictedInstance]]) -> np.ndarray: + instances_a: List[Union[Instance, PredictedInstance]], + instances_b: List[Union[Instance, PredictedInstance]], +) -> np.ndarray: """Given two lists of corresponding Instances, returns (instances * nodes) matrix of distances between corresponding nodes.""" @@ -203,29 +221,31 @@ def compare_instance_lists( return np.stack(paired_points_array_distances) -def list_points_array(instances: List[Union[Instance, PredictedInstance]]) -> np.ndarray: + +def list_points_array( + instances: List[Union[Instance, PredictedInstance]] +) -> np.ndarray: """Given list of Instances, returns (instances * nodes * 2) matrix.""" - points_arrays = list(map(lambda inst: inst.points_array(invisible_as_nan=True), instances)) + points_arrays = list(map(lambda inst: inst.points_array, instances)) return np.stack(points_arrays) -def point_match_count(dist_array: np.ndarray, thresh: float=5) -> int: + +def point_match_count(dist_array: np.ndarray, thresh: float = 5) -> int: """Given an array of distances, returns number which are <= threshold.""" return np.sum(dist_array[~np.isnan(dist_array)] <= thresh) -def point_nonmatch_count(dist_array: np.ndarray, thresh: float=5) -> int: + +def point_nonmatch_count(dist_array: np.ndarray, thresh: float = 5) -> int: """Given an array of distances, returns number which are not <= threshold.""" return dist_array.shape[0] - point_match_count(dist_array, thresh) -def foo(labels_gt, labels_pr, frame_idx=1092): - list_a = labels_gt.find(labels_gt.videos[0], frame_idx=frame_idx)[0].instances - list_b = labels_pr.find(labels_pr.videos[0], frame_idx=frame_idx)[0].instances - - match_instance_lists_nodewise(list_a, list_b) if __name__ == "__main__": labels_gt = Labels.load_json("tests/data/json_format_v1/centered_pair.json") - labels_pr = Labels.load_json("tests/data/json_format_v2/centered_pair_predictions.json") + labels_pr = Labels.load_json( + "tests/data/json_format_v2/centered_pair_predictions.json" + ) # OPTION 1 @@ -241,7 +261,9 @@ def foo(labels_gt, labels_pr, frame_idx=1092): # where "match" means the points are within some threshold distance. # Note that each sorted list will be as long as the shorted input list. - instwise_matching_func = lambda gt_list, pr_list: match_instance_lists(gt_list, pr_list, point_nonmatch_count) + instwise_matching_func = lambda gt_list, pr_list: match_instance_lists( + gt_list, pr_list, point_nonmatch_count + ) # PICK THE FUNCTION @@ -249,7 +271,9 @@ def foo(labels_gt, labels_pr, frame_idx=1092): # inst_matching_func = instwise_matching_func # Calculate distances - frame_idxs, D, points_gt, points_pr = matched_instance_distances(labels_gt, labels_pr, inst_matching_func) + frame_idxs, D, points_gt, points_pr = matched_instance_distances( + labels_gt, labels_pr, inst_matching_func + ) # Show mean difference for each node node_names = labels_gt.skeletons[0].node_names diff --git a/sleap/info/summary.py b/sleap/info/summary.py new file mode 100644 index 000000000..b907aeb09 --- /dev/null +++ b/sleap/info/summary.py @@ -0,0 +1,144 @@ +""" +Module for getting a series which gives some statistic based on labeling +data for each frame of some labeled video. +""" + +import attr +import numpy as np + +from typing import Callable, Dict + +from sleap.io.dataset import Labels +from sleap.io.video import Video + + +@attr.s(auto_attribs=True) +class StatisticSeries: + """ + Class to calculate various statistical series for labeled frames. + + Each method returns a series which is a dictionary in which keys + are frame index and value are some numerical value for the frame. + + Args: + labels: The :class:`Labels` for which to calculate series. + """ + + labels: Labels + + def get_point_count_series(self, video: Video) -> Dict[int, float]: + """Get series with total number of labeled points in each frame.""" + series = dict() + + for lf in self.labels.find(video): + val = sum(len(inst.points) for inst in lf if hasattr(inst, "score")) + series[lf.frame_idx] = val + return series + + def get_point_score_series( + self, video: Video, reduction: str = "sum" + ) -> Dict[int, float]: + """Get series with statistic of point scores in each frame. + + Args: + video: The :class:`Video` for which to calculate statistic. + reduction: name of function applied to scores: + * sum + * min + + Returns: + The series dictionary (see class docs for details) + """ + reduce_funct = dict(sum=sum, min=lambda x: min(x, default=0))[reduction] + + series = dict() + + for lf in self.labels.find(video): + val = reduce_funct( + point.score + for inst in lf + for point in inst.points + if hasattr(inst, "score") + ) + series[lf.frame_idx] = val + return series + + def get_instance_score_series(self, video, reduction="sum") -> Dict[int, float]: + """Get series with statistic of instance scores in each frame. + + Args: + video: The :class:`Video` for which to calculate statistic. + reduction: name of function applied to scores: + * sum + * min + + Returns: + The series dictionary (see class docs for details) + """ + reduce_funct = dict(sum=sum, min=lambda x: min(x, default=0))[reduction] + + series = dict() + + for lf in self.labels.find(video): + val = reduce_funct(inst.score for inst in lf if hasattr(inst, "score")) + series[lf.frame_idx] = val + return series + + def get_point_displacement_series(self, video, reduction="sum") -> Dict[int, float]: + """ + Get series with statistic of point displacement in each frame. + + Point displacement is the distance between the point location in + frame and the location of the corresponding point (same node, + same track) from the closest earlier labeled frame. + + Args: + video: The :class:`Video` for which to calculate statistic. + reduction: name of function applied to point scores: + * sum + * mean + * max + + Returns: + The series dictionary (see class docs for details) + """ + reduce_funct = dict(sum=np.sum, mean=np.nanmean, max=np.max)[reduction] + + series = dict() + + last_lf = None + for lf in self.labels.find(video): + val = self._calculate_frame_velocity(lf, last_lf, reduce_funct) + last_lf = lf + if not np.isnan(val): + series[lf.frame_idx] = val # len(lf.instances) + return series + + @staticmethod + def _calculate_frame_velocity( + lf: "LabeledFrame", last_lf: "LabeledFrame", reduce_function: Callable + ) -> float: + """ + Calculate total point displacement between two given frames. + + Args: + lf: The :class:`LabeledFrame` for which we want velocity + last_lf: The frame from which to calculate displacement. + reduce_function: Numpy function (e.g., np.sum, np.nanmean) + is applied to *point* displacement, and then those + instance values are summed for the whole frame. + + Returns: + The total velocity for instances in frame. + """ + val = 0 + for inst in lf: + if last_lf is not None: + last_inst = last_lf.find(track=inst.track) + if last_inst: + points_a = inst.points_array + points_b = last_inst[0].points_array + point_dist = np.linalg.norm(points_a - points_b, axis=1) + inst_dist = reduce_function(point_dist) + val += inst_dist if not np.isnan(inst_dist) else 0 + return val diff --git a/sleap/info/write_tracking_h5.py b/sleap/info/write_tracking_h5.py index fe312a3dd..a91a815fa 100644 --- a/sleap/info/write_tracking_h5.py +++ b/sleap/info/write_tracking_h5.py @@ -1,87 +1,221 @@ +""" +Generate an HDF5 file with track occupancy and point location data. + +Ignores tracks that are entirely empty. By default will also ignore +empty frames from the beginning and end of video, although +`--all-frames` argument will make it include empty frames from beginning +of video. + +Call from command line as: + +>>> python -m sleap.io.write_tracking_h5 + +Will write file to `.tracking.h5`. + +The HDF5 file has these datasets: +* "track_occupancy" shape: tracks * frames +* "tracks" shape: frames * nodes * 2 * tracks +* "track_names" shape: tracks + +Note: the datasets are stored column-major as expected by MATLAB. +""" + import os import re import h5py as h5 import numpy as np +from typing import Any, Dict, List, Tuple + from sleap.io.dataset import Labels -if __name__ == "__main__": - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("data_path", help="Path to labels json file") - args = parser.parse_args() +def get_tracks_as_np_strings(labels: Labels) -> List[np.string_]: + """Get list of track names as `np.string_`.""" + return [np.string_(track.name) for track in labels.tracks] - def video_callback(video_list, new_paths=[os.path.dirname(args.data_path)]): - # Check each video - for video_item in video_list: - if "backend" in video_item and "filename" in video_item["backend"]: - current_filename = video_item["backend"]["filename"] - # check if we can find video - if not os.path.exists(current_filename): - is_found = False - - current_basename = os.path.basename(current_filename) - # handle unix, windows, or mixed paths - if current_basename.find("/") > -1: - current_basename = current_basename.split("/")[-1] - if current_basename.find("\\") > -1: - current_basename = current_basename.split("\\")[-1] - - # First see if we can find the file in another directory, - # and if not, prompt the user to find the file. - - # We'll check in the current working directory, and if the user has - # already found any missing videos, check in the directory of those. - for path_dir in new_paths: - check_path = os.path.join(path_dir, current_basename) - if os.path.exists(check_path): - # we found the file in a different directory - video_item["backend"]["filename"] = check_path - is_found = True - break - labels = Labels.load_file(args.data_path, video_callback=video_callback) +def get_occupancy_and_points_matrices( + labels: Labels, all_frames: bool +) -> Tuple[np.ndarray, np.ndarray]: + """ + Builds numpy matrices with track occupancy and point location data. + + Args: + labels: The :class:`Labels` from which to get data. + all_frames: If True, then includes zeros so that frame index + will line up with columns in the output. Otherwise, + there will only be columns for the frames between the + first and last frames with labeling data. + + Returns: + tuple of two matrices: - frame_count = len(labels) + * occupancy matrix with shape (tracks, frames) + * point location matrix with shape (frames, nodes, 2, tracks) + """ track_count = len(labels.tracks) - track_names = [np.string_(track.name) for track in labels.tracks] node_count = len(labels.skeletons[0].nodes) frame_idxs = [lf.frame_idx for lf in labels] frame_idxs.sort() + first_frame_idx = 0 if all_frames else frame_idxs[0] + + frame_count = ( + frame_idxs[-1] - first_frame_idx + 1 + ) # count should include unlabeled frames + # Desired MATLAB format: # "track_occupancy" tracks * frames # "tracks" frames * nodes * 2 * tracks # "track_names" tracks occupancy_matrix = np.zeros((track_count, frame_count), dtype=np.uint8) - prediction_matrix = np.full((frame_count, node_count, 2, track_count), np.nan, dtype=float) - + locations_matrix = np.full( + (frame_count, node_count, 2, track_count), np.nan, dtype=float + ) + for lf, inst in [(lf, inst) for lf in labels for inst in lf.instances]: - frame_i = frame_idxs.index(lf.frame_idx) + frame_i = lf.frame_idx - first_frame_idx track_i = labels.tracks.index(inst.track) occupancy_matrix[track_i, frame_i] = 1 - inst_points = inst.points_array(invisible_as_nan=True) - prediction_matrix[frame_i, ..., track_i] = inst_points - - print(f"track_occupancy: {occupancy_matrix.shape}") - print(f"tracks: {prediction_matrix.shape}") - - output_filename = re.sub("(\.json(\.zip)?|\.h5)$", "", args.data_path) - output_filename = output_filename + ".tracking.h5" - - with h5.File(output_filename, "w") as f: - # We have to transpose the arrays since MATLAB expects column-major - ds = f.create_dataset("track_names", data=track_names) - ds = f.create_dataset( - "track_occupancy", data=np.transpose(occupancy_matrix), - compression="gzip", compression_opts=9) - ds = f.create_dataset( - "tracks", data=np.transpose(prediction_matrix), - compression="gzip", compression_opts=9) - - print(f"Saved as {output_filename}") \ No newline at end of file + inst_points = inst.points_array + locations_matrix[frame_i, ..., track_i] = inst_points + + return occupancy_matrix, locations_matrix + + +def remove_empty_tracks_from_matrices( + track_names: List, occupancy_matrix: np.ndarray, locations_matrix: np.ndarray +) -> Tuple[List, np.ndarray, np.ndarray]: + """ + Removes matrix rows/columns for unoccupied tracks. + + Args: + track_names: List of track names + occupancy_matrix: 2d numpy matrix, rows correspond to tracks + locations_matrix: 4d numpy matrix, last index is track + + Returns: + track_names, occupancy_matrix, locations_matrix from input, + but without the rows/columns corresponding to unoccupied tracks. + """ + # Make mask with only the occupied tracks + occupied_track_mask = np.sum(occupancy_matrix, axis=1) > 0 + + # Ignore unoccupied tracks + if np.sum(~occupied_track_mask): + + print(f"ignoring {np.sum(~occupied_track_mask)} empty tracks") + + occupancy_matrix = occupancy_matrix[occupied_track_mask] + locations_matrix = locations_matrix[..., occupied_track_mask] + track_names = [ + track_names[i] for i in range(len(track_names)) if occupied_track_mask[i] + ] + + return track_names, occupancy_matrix, locations_matrix + + +def write_occupancy_file( + output_path: str, data_dict: Dict[str, Any], transpose: bool = True +): + """ + Write HDF5 file with data from given dictionary. + + Args: + output_path: Path of HDF5 file. + data_dict: Dictionary with data to save. Keys are dataset names, + values are the data. + transpose: If True, then any ndarray in data dictionary will be + transposed before saving. This is useful for writing files + that will be imported into MATLAB, which expects data in + column-major format. + + Returns: + None + """ + + with h5.File(output_path, "w") as f: + for key, val in data_dict.items(): + if isinstance(val, np.ndarray): + print(f"{key}: {val.shape}") + + if transpose: + # Transpose since MATLAB expects column-major + f.create_dataset( + key, + data=np.transpose(val), + compression="gzip", + compression_opts=9, + ) + else: + f.create_dataset( + key, data=val, compression="gzip", compression_opts=9 + ) + else: + print(f"{key}: {len(val)}") + f.create_dataset(key, data=val) + + print(f"Saved as {output_path}") + + +def main(labels: Labels, output_path: str, all_frames: bool = True): + """ + Writes HDF5 file with matrices of track occupancy and coordinates. + + Args: + labels: The :class:`Labels` from which to get data. + output_path: Path of HDF5 file to create. + all_frames: If True, then includes zeros so that frame index + will line up with columns in the output. Otherwise, + there will only be columns for the frames between the + first and last frames with labeling data. + + Returns: + None + """ + track_names = get_tracks_as_np_strings(labels) + + occupancy_matrix, locations_matrix = get_occupancy_and_points_matrices( + labels, all_frames + ) + + track_names, occupancy_matrix, locations_matrix = remove_empty_tracks_from_matrices( + track_names, occupancy_matrix, locations_matrix + ) + + data_dict = dict( + track_names=track_names, + tracks=locations_matrix, + track_occupancy=occupancy_matrix, + ) + + write_occupancy_file(output_path, data_dict, transpose=True) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("data_path", help="Path to labels json file") + parser.add_argument( + "--all-frames", + dest="all_frames", + action="store_const", + const=True, + default=False, + help="include all frames without predictions", + ) + args = parser.parse_args() + + video_callback = Labels.make_video_callback([os.path.dirname(args.data_path)]) + labels = Labels.load_file(args.data_path, video_callback=video_callback) + + output_path = re.sub("(\.json(\.zip)?|\.h5)$", "", args.data_path) + output_path = output_path + ".tracking.h5" + + main(labels, output_path=output_path, all_frames=args.all_frames) diff --git a/sleap/instance.py b/sleap/instance.py index d03f8cfd0..9ca1d2c53 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -1,14 +1,28 @@ """ +Data structures for all labeled data contained with a SLEAP project. +The relationships between objects in this module: + +* A `LabeledFrame` can contain zero or more `Instance`s + (and `PredictedInstance` objects). + +* `Instance` objects (and `PredictedInstance` objects) have `PointArray` + (or `PredictedPointArray`). + +* `Instance` (`PredictedInstance`) can be associated with a `Track` + +* A `PointArray` (or `PredictedPointArray`) contains zero or more + `Point` objects (or `PredictedPoint` objectss), ideally as many as + there are in the associated :class:`Skeleton` although these can get + out of sync if the skeleton is manipulated. """ import math import numpy as np -import h5py as h5 -import pandas as pd import cattr +from copy import copy from typing import Dict, List, Optional, Union, Tuple from numpy.lib.recfunctions import structured_to_unstructured @@ -26,26 +40,27 @@ class Point(np.record): """ - A very simple class to define a labelled point and any metadata associated with it. + A labelled point and any metadata associated with it. Args: - x: The horizontal pixel location of the point within the image frame. - y: The vertical pixel location of the point within the image frame. + x: The horizontal pixel location of point within image frame. + y: The vertical pixel location of point within image frame. visible: Whether point is visible in the labelled image or not. - complete: Has the point been verified by the a user labeler. + complete: Has the point been verified by the user labeler. """ # Define the dtype from the point class attributes plus some # additional fields we will use to relate point to instances and # nodes. - dtype = np.dtype( - [('x', 'f8'), - ('y', 'f8'), - ('visible', '?'), - ('complete', '?')]) + dtype = np.dtype([("x", "f8"), ("y", "f8"), ("visible", "?"), ("complete", "?")]) - def __new__(cls, x: float = math.nan, y: float = math.nan, - visible: bool = True, complete: bool = False): + def __new__( + cls, + x: float = math.nan, + y: float = math.nan, + visible: bool = True, + complete: bool = False, + ) -> "Point": # HACK: This is a crazy way to instantiate at new Point but I can't figure # out how recarray does it. So I just use it to make matrix of size 1 and @@ -62,10 +77,10 @@ def __new__(cls, x: float = math.nan, y: float = math.nan, return val - def __str__(self): + def __str__(self) -> str: return f"({self.x}, {self.y})" - def isnan(self): + def isnan(self) -> bool: """ Are either of the coordinates a NaN value. @@ -77,37 +92,38 @@ def isnan(self): # This turns PredictedPoint into an attrs class. Defines comparators for # us and generaly makes it behave better. Crazy that this works! -Point = attr.s(these={name: attr.ib() - for name in Point.dtype.names}, - init=False)(Point) +Point = attr.s(these={name: attr.ib() for name in Point.dtype.names}, init=False)(Point) class PredictedPoint(Point): """ - A predicted point is an output of the inference procedure. It has all - the properties of a labeled point with an accompanying score. + A predicted point is an output of the inference procedure. + + It has all the properties of a labeled point, plus a score. Args: - x: The horizontal pixel location of the point within the image frame. - y: The vertical pixel location of the point within the image frame. + x: The horizontal pixel location of point within image frame. + y: The vertical pixel location of point within image frame. visible: Whether point is visible in the labelled image or not. - complete: Has the point been verified by the a user labeler. - score: The point level prediction score. + complete: Has the point been verified by the user labeler. + score: The point-level prediction score. """ # Define the dtype from the point class attributes plus some # additional fields we will use to relate point to instances and # nodes. dtype = np.dtype( - [('x', 'f8'), - ('y', 'f8'), - ('visible', '?'), - ('complete', '?'), - ('score', 'f8')]) - - def __new__(cls, x: float = math.nan, y: float = math.nan, - visible: bool = True, complete: bool = False, - score: float = 0.0): + [("x", "f8"), ("y", "f8"), ("visible", "?"), ("complete", "?"), ("score", "f8")] + ) + + def __new__( + cls, + x: float = math.nan, + y: float = math.nan, + visible: bool = True, + complete: bool = False, + score: float = 0.0, + ) -> "PredictedPoint": # HACK: This is a crazy way to instantiate at new Point but I can't figure # out how recarray does it. So I just use it to make matrix of size 1 and @@ -126,7 +142,7 @@ def __new__(cls, x: float = math.nan, y: float = math.nan, return val @classmethod - def from_point(cls, point: Point, score: float = 0.0): + def from_point(cls, point: Point, score: float = 0.0) -> "PredictedPoint": """ Create a PredictedPoint from a Point @@ -137,14 +153,14 @@ def from_point(cls, point: Point, score: float = 0.0): Returns: A scored point based on the point passed in. """ - return cls(**{**Point.asdict(point), 'score': score}) + return cls(**{**Point.asdict(point), "score": score}) # This turns PredictedPoint into an attrs class. Defines comparators for # us and generaly makes it behave better. Crazy that this works! -PredictedPoint = attr.s(these={name: attr.ib() - for name in PredictedPoint.dtype.names}, - init=False)(PredictedPoint) +PredictedPoint = attr.s( + these={name: attr.ib() for name in PredictedPoint.dtype.names}, init=False +)(PredictedPoint) class PointArray(np.recarray): @@ -155,9 +171,19 @@ class PointArray(np.recarray): _record_type = Point - def __new__(subtype, shape, buf=None, offset=0, strides=None, - formats=None, names=None, titles=None, - byteorder=None, aligned=False, order='C'): + def __new__( + subtype, + shape, + buf=None, + offset=0, + strides=None, + formats=None, + names=None, + titles=None, + byteorder=None, + aligned=False, + order="C", + ) -> "PointArray": dtype = subtype._record_type.dtype @@ -167,25 +193,38 @@ def __new__(subtype, shape, buf=None, offset=0, strides=None, descr = np.format_parser(formats, names, titles, aligned, byteorder)._descr if buf is None: - self = np.ndarray.__new__(subtype, shape, (subtype._record_type, descr), order=order) + self = np.ndarray.__new__( + subtype, shape, (subtype._record_type, descr), order=order + ) else: - self = np.ndarray.__new__(subtype, shape, (subtype._record_type, descr), - buffer=buf, offset=offset, - strides=strides, order=order) + self = np.ndarray.__new__( + subtype, + shape, + (subtype._record_type, descr), + buffer=buf, + offset=offset, + strides=strides, + order=order, + ) return self def __array_finalize__(self, obj): """ - Overide __array_finalize__ on recarray because it converting the dtype - of any np.void subclass to np.record, we don't want this. + Override :method:`np.recarray.__array_finalize__()`. + + Overide __array_finalize__ on recarray because it converting the + dtype of any np.void subclass to np.record, we don't want this. """ pass @classmethod - def make_default(cls, size: int): + def make_default(cls, size: int) -> "PointArray": """ - Construct a point array of specific size where each value in the array - is assigned the default values for a Point. + Construct a point array where points are all set to default. + + The constructed :class:`PointArray` will have specified size + and each value in the array is assigned the default values for + a :class:`Point``. Args: size: The number of points to allocate. @@ -197,7 +236,8 @@ def make_default(cls, size: int): p[:] = cls._record_type() return p - def __getitem__(self, indx): + def __getitem__(self, indx: int) -> "Point": + """Get point by its index in the array.""" obj = super(np.recarray, self).__getitem__(indx) # copy behavior of getattr, except that here @@ -205,7 +245,7 @@ def __getitem__(self, indx): if isinstance(obj, np.ndarray): if obj.dtype.fields: obj = obj.view(type(self)) - #if issubclass(obj.dtype.type, numpy.void): + # if issubclass(obj.dtype.type, numpy.void): # return obj.view(dtype=(self.dtype.type, obj.dtype)) return obj else: @@ -215,17 +255,21 @@ def __getitem__(self, indx): return obj @classmethod - def from_array(cls, a: 'PointArray'): + def from_array(cls, a: "PointArray") -> "PointArray": """ - Convert a PointArray to a new PointArray - (or child class, i.e., PredictedPointArray), - use the default attribute values for new array. + Converts a :class:`PointArray` (or child) to a new instance. + + This will convert an object to the same type as itself, + so a :class:`PredictedPointArray` will result in the same. + + Uses the default attribute values for new array. Args: a: The array to convert. Returns: - A PredictedPointArray with the same points as a. + A :class:`PointArray` or :class:`PredictedPointArray` with + the same points as a. """ v = cls.make_default(len(a)) @@ -234,15 +278,17 @@ def from_array(cls, a: 'PointArray'): return v + class PredictedPointArray(PointArray): """ PredictedPointArray is analogous to PointArray except for predicted points. """ + _record_type = PredictedPoint @classmethod - def to_array(cls, a: 'PredictedPointArray'): + def to_array(cls, a: "PredictedPointArray") -> "PointArray": """ Convert a PredictedPointArray to a normal PointArray. @@ -263,18 +309,19 @@ def to_array(cls, a: 'PredictedPointArray'): @attr.s(slots=True, cmp=False) class Track: """ - A track object is associated with a set of animal/object instances across multiple - frames of video. This allows tracking of unique entities in the video over time and - space. + A track object is associated with a set of animal/object instances + across multiple frames of video. This allows tracking of unique + entities in the video over time and space. Args: - spawned_on: The frame of the video that this track was spawned on. + spawned_on: The video frame that this track was spawned on. name: A name given to this track for identifying purposes. """ + spawned_on: int = attr.ib(converter=int) name: str = attr.ib(default="", converter=str) - def matches(self, other: 'Track'): + def matches(self, other: "Track"): """ Check if two tracks match by value. @@ -292,66 +339,118 @@ def matches(self, other: 'Track'): # attributes _frame and _point_array_cache after init. These are private variables # that are created in post init so they are not serialized. + @attr.s(cmp=False, slots=True) class Instance: """ - The class :class:`Instance` represents a labelled instance of skeleton + The class :class:`Instance` represents a labelled instance of a skeleton. Args: skeleton: The skeleton that this instance is associated with. - points: A dictionary where keys are skeleton node names and values are Point objects. Alternatively, - a point array whose length and order matches skeleton.nodes - track: An optional multi-frame object track associated with this instance. - This allows individual animals/objects to be tracked across frames. - from_predicted: The predicted instance (if any) that this was copied from. - frame: A back reference to the LabeledFrame that this Instance belongs to. - This field is set when Instances are added to LabeledFrame objects. + points: A dictionary where keys are skeleton node names and + values are Point objects. Alternatively, a point array whose + length and order matches skeleton.nodes. + track: An optional multi-frame object track associated with + this instance. This allows individual animals/objects to be + tracked across frames. + from_predicted: The predicted instance (if any) that this was + copied from. + frame: A back reference to the :class:`LabeledFrame` that this + :class:`Instance` belongs to. This field is set when + instances are added to :class:`LabeledFrame` objects. """ skeleton: Skeleton = attr.ib() track: Track = attr.ib(default=None) - from_predicted: Optional['PredictedInstance'] = attr.ib(default=None) + from_predicted: Optional["PredictedInstance"] = attr.ib(default=None) _points: PointArray = attr.ib(default=None) _nodes: List = attr.ib(default=None) - frame: Union['LabeledFrame', None] = attr.ib(default=None) + frame: Union["LabeledFrame", None] = attr.ib(default=None) # The underlying Point array type that this instances point array should be. _point_array_type = PointArray @from_predicted.validator - def _validate_from_predicted_(self, attribute, from_predicted): + def _validate_from_predicted_( + self, attribute, from_predicted: Optional["PredictedInstance"] + ): + """ + Validation method called by attrs. + + Checks that from_predicted is None or :class:`PredictedInstance` + + Args: + attribute: Attribute being validated; not used. + from_predicted: Value being validated. + + Raises: + TypeError: If from_predicted is anything other than None + or a `PredictedInstance`. + + """ if from_predicted is not None and type(from_predicted) != PredictedInstance: - raise TypeError(f"Instance.from_predicted type must be PredictedInstance (not {type(from_predicted)})") + raise TypeError( + f"Instance.from_predicted type must be PredictedInstance (not {type(from_predicted)})" + ) @_points.validator - def _validate_all_points(self, attribute, points): + def _validate_all_points(self, attribute, points: Union[dict, PointArray]): """ - Function that makes sure all the _points defined for the skeleton are found in the skeleton. + Validation method called by attrs. - Returns: - None + Checks that all the _points defined for the skeleton are found + in the skeleton. + + Args: + attribute: Attribute being validated; not used. + points: Either dict of points or PointArray + If dict, keys should be node names. Raises: - ValueError: If a point is associated with a skeleton node name that doesn't exist. + ValueError: If a point is associated with a skeleton node + name that doesn't exist. + + Returns: + None """ if type(points) is dict: is_string_dict = set(map(type, points)) == {str} if is_string_dict: for node_name in points.keys(): if not self.skeleton.has_node(node_name): - raise KeyError(f"There is no node named {node_name} in {self.skeleton}") + raise KeyError( + f"There is no node named {node_name} in {self.skeleton}" + ) elif isinstance(points, PointArray): if len(points) != len(self.skeleton.nodes): - raise ValueError("PointArray does not have the same number of rows as skeleton nodes.") + raise ValueError( + "PointArray does not have the same number of rows as skeleton nodes." + ) def __attrs_post_init__(self): + """ + Method called by attrs after __init__() + + Initializes points if none were specified when creating object, + caches list of nodes so what we can still find points in array + if the `Skeleton` changes. + + Args: + None + + Raises: + ValueError: If object has no `Skeleton`. + + Returns: + None + """ if not self.skeleton: raise ValueError("No skeleton set for Instance") # If the user did not pass a points list initialize a point array for future # points. - if self._points is None: + if self._points is None or len(self._points) == 0: # Initialize an empty point array that is the size of the skeleton. self._points = self._point_array_type.make_default(len(self.skeleton.nodes)) @@ -369,7 +468,26 @@ def __attrs_post_init__(self): self._nodes = self.skeleton.nodes @staticmethod - def _points_dict_to_array(points, parray, skeleton): + def _points_dict_to_array( + points: Dict[Union[str, Node], Point], parray: PointArray, skeleton: Skeleton + ): + """ + Sets values in given :class:`PointsArray` from dictionary. + + Args: + points: The dictionary of points. Keys can be either node + names or :class:`Node`s, values are :class:`Point`s. + parray: The :class:`PointsArray` which is being updated. + skeleton: The :class:`Skeleton` which contains the nodes + referenced in the dictionary of points. + + Raises: + ValueError: If dictionary keys are not either all strings + or all :class:`Node`s. + + Returns: + None + """ # Check if the dict contains all strings is_string_dict = set(map(type, points)) == {str} @@ -384,41 +502,53 @@ def _points_dict_to_array(points, parray, skeleton): points = {skeleton.find_node(name): point for name, point in points.items()} if not is_string_dict and not is_node_dict: - raise ValueError("points dictionary must be keyed by either strings " + - "(node names) or Nodes.") + raise ValueError( + "points dictionary must be keyed by either strings " + + "(node names) or Nodes." + ) # Get rid of the points dict and replace with equivalent point array. for node, point in points.items(): # Convert PredictedPoint to Point if Instance if type(parray) == PointArray and type(point) == PredictedPoint: - point = Point(x=point.x, y=point.y, visible=point.visible, complete=point.complete) + point = Point( + x=point.x, y=point.y, visible=point.visible, complete=point.complete + ) try: parray[skeleton.node_to_index(node)] = point # parray[skeleton.node_to_index(node.name)] = point except: pass - def _node_to_index(self, node_name): + def _node_to_index(self, node: Union[str, Node]) -> int: """ Helper method to get the index of a node from its name. Args: - node_name: The name of the node. + node: Node name or :class:`Node` object. Returns: The index of the node on skeleton graph. """ - return self.skeleton.node_to_index(node_name) + return self.skeleton.node_to_index(node) - def __getitem__(self, node): + def __getitem__( + self, node: Union[List[Union[str, Node]], Union[str, Node]] + ) -> Union[List[Point], Point]: """ - Get the Points associated with particular skeleton node or list of skeleton nodes + Get the Points associated with particular skeleton node(s). Args: - node: A single node or list of nodes within the skeleton associated with this instance. + node: A single node or list of nodes within the skeleton + associated with this instance. + + Raises: + KeyError: If node cannot be found in skeleton. Returns: - A single point of list of _points related to the nodes provided as argument. + Either a single point (if a single node given), or + a list of points (if a list of nodes given) corresponding + to each node. """ @@ -434,20 +564,23 @@ def __getitem__(self, node): node = self._node_to_index(node) return self._points[node] except ValueError: - raise KeyError(f"The underlying skeleton ({self.skeleton}) has no node '{node}'") + raise KeyError( + f"The underlying skeleton ({self.skeleton}) has no node '{node}'" + ) - def __contains__(self, node): + def __contains__(self, node: Union[str, Node]) -> bool: """ - Returns True if this instance has a point with the specified node. + Whether this instance has a point with the specified node. Args: - node: node name + node: Node name or :class:`Node` object. Returns: - bool: True if the point with the node name specified has a point in this instance. + bool: True if the point with the node name specified has a + point in this instance. """ - if type(node) is Node: + if isinstance(node, Node): node = node.name if node not in self.skeleton: @@ -458,14 +591,37 @@ def __contains__(self, node): # If the points are nan, then they haven't been allocated. return not self._points[node_idx].isnan() - def __setitem__(self, node, value): + def __setitem__( + self, + node: Union[List[Union[str, Node]], Union[str, Node]], + value: Union[List[Point], Point], + ): + """ + Set the point(s) for given node(s). + + Args: + node: Either node (by name or `Node`) or list of nodes. + value: Either `Point` or list of `Point`s. + + Raises: + IndexError: If lengths of lists don't match, or if exactly + one of the inputs is a list. + KeyError: If skeleton does not have (one of) the node(s). + + Returns: + None + """ # Make sure node and value, if either are lists, are of compatible size if type(node) is not list and type(value) is list and len(value) != 1: - raise IndexError("Node list for indexing must be same length and value list.") + raise IndexError( + "Node list for indexing must be same length and value list." + ) if type(node) is list and type(value) is not list and len(node) != 1: - raise IndexError("Node list for indexing must be same length and value list.") + raise IndexError( + "Node list for indexing must be same length and value list." + ) # If we are dealing with lists, do multiple assignment recursively, this should be ok because # skeletons and instances are small. @@ -477,23 +633,40 @@ def __setitem__(self, node, value): node_idx = self._node_to_index(node) self._points[node_idx] = value except ValueError: - raise KeyError(f"The underlying skeleton ({self.skeleton}) has no node '{node}'") + raise KeyError( + f"The underlying skeleton ({self.skeleton}) has no node '{node}'" + ) + + def __delitem__(self, node: Union[str, Node]): + """ + Delete node key and points associated with that node. - def __delitem__(self, node): - """ Delete node key and points associated with that node. """ + Args: + node: Node name or :class:`Node` object. + + Raises: + KeyError: If skeleton does not have the node. + + Returns: + None + """ try: node_idx = self._node_to_index(node) self._points[node_idx].x = math.nan self._points[node_idx].y = math.nan except ValueError: - raise KeyError(f"The underlying skeleton ({self.skeleton}) has no node '{node}'") + raise KeyError( + f"The underlying skeleton ({self.skeleton}) has no node '{node}'" + ) - def matches(self, other): + def matches(self, other: "Instance") -> bool: """ - Compare this `Instance` to another, modulo the particular `Node` objects. + Whether two instances match by value. + + Checks the types, points, track, and frame index. Args: - other: The other instance. + other: The other :class:`Instance`. Returns: True if match, False otherwise. @@ -501,7 +674,7 @@ def matches(self, other): if type(self) is not type(other): return False - if list(self.points()) != list(other.points()): + if list(self.points) != list(other.points): return False if not self.skeleton.matches(other.skeleton): @@ -520,42 +693,44 @@ def matches(self, other): return True @property - def nodes(self): + def nodes(self) -> Tuple[Node, ...]: """ - Get the list of nodes that have been labelled for this instance. - - Returns: - A tuple of nodes that have been labelled for this instance. - + The tuple of nodes that have been labelled for this instance. """ - self.fix_array() - return tuple(self._nodes[i] for i, point in enumerate(self._points) - if not point.isnan() and self._nodes[i] in self.skeleton.nodes) + self._fix_array() + return tuple( + self._nodes[i] + for i, point in enumerate(self._points) + if not point.isnan() and self._nodes[i] in self.skeleton.nodes + ) @property - def nodes_points(self): + def nodes_points(self) -> List[Tuple[Node, Point]]: """ - Return view object that displays a list of the instance's (node, point) tuple pairs - for all labelled point. - - Returns: - The instance's (node, point) tuple pairs for all labelled point. + The list of (node, point) tuples for all labelled points. """ - names_to_points = dict(zip(self.nodes, self.points())) + names_to_points = dict(zip(self.nodes, self.points)) return names_to_points.items() - def points(self) -> Tuple[Point]: + @property + def points(self) -> Tuple[Point, ...]: + """ + The tuple of labelled points, in order they were labelled. + """ + self._fix_array() + return tuple(point for point in self._points if not point.isnan()) + + def _fix_array(self): """ - Return the list of labelled points, in order they were labelled. + Fixes PointArray after nodes have been added or removed. + + This updates the PointArray as required by comparing the cached + list of nodes to the nodes in the `Skeleton` object (which may + have changed). Returns: - The list of labelled points, in order they were labelled. + None """ - self.fix_array() - return tuple(point for point in self._points if not point.isnan()) - - def fix_array(self): - """Fix points array after nodes have been added or removed.""" # Check if cached skeleton nodes are different than current nodes if self._nodes != self.skeleton.nodes: @@ -572,56 +747,73 @@ def fix_array(self): self._points = new_array self._nodes = self.skeleton.nodes - def points_array(self, copy: bool = True, - invisible_as_nan: bool = False, - full: bool = False) -> np.ndarray: + def get_points_array( + self, copy: bool = True, invisible_as_nan: bool = False, full: bool = False + ) -> np.ndarray: """ Return the instance's points in array form. Args: copy: If True, the return a copy of the points array as an - Nx2 ndarray where first column is x and second column is y. - If False, return a view of the underlying recarray. + Nx2 ndarray where first column is x and second is y. + If False, return a view of the underlying recarray. invisible_as_nan: Should invisible points be marked as NaN. - full: If True, return the raw underlying recarray with all attributes - of the point, if not, return just the x and y coordinate. Assumes - copy is False and invisible_as_nan is False. + full: If True, return the raw underlying recarray with all + attributes of the point. + Otherwise, return just the x and y coordinate. + Assumes copy is False and invisible_as_nan is False. Returns: A Nx2 array containing x and y coordinates of each point - as the rows of the array and N is the number of nodes in the skeleton. - The order of the rows corresponds to the ordering of the skeleton nodes. - Any skeleton node not defined will have NaNs present. + as the rows of the array and N is the number of nodes in the + skeleton. The order of the rows corresponds to the ordering + of the skeleton nodes. Any skeleton node not defined will + have NaNs present. """ - self.fix_array() + self._fix_array() if full: return self._points if not copy and not invisible_as_nan: - return self._points[['x', 'y']] + return self._points[["x", "y"]] else: - parray = structured_to_unstructured(self._points[['x', 'y']]) + parray = structured_to_unstructured(self._points[["x", "y"]]) if invisible_as_nan: parray[~self._points.visible] = math.nan return parray + @property + def points_array(self) -> np.ndarray: + """ + Nx2 array of x and y for visible points. + + Row in arrow corresponds to order of points in skeleton. + Invisible points will have nans. + + Returns: + ndarray of visible point coordinates. + """ + return self.get_points_array(invisible_as_nan=True) + @property def centroid(self) -> np.ndarray: """Returns instance centroid as (x,y) numpy row vector.""" - points = self.points_array(invisible_as_nan=True) + points = self.points_array centroid = np.nanmedian(points, axis=0) return centroid @property - def frame_idx(self) -> Union[None, int]: + def frame_idx(self) -> Optional[int]: """ - Get the index of the frame that this instance was found on. This is a convenience - method for Instance.frame.frame_idx. + Get the index of the frame that this instance was found on. + + This is a convenience method for Instance.frame.frame_idx. Returns: - The frame number this instance was found on. + The frame number this instance was found on, or None if the + instance is not associated with frame. """ if self.frame is None: return None @@ -632,12 +824,12 @@ def frame_idx(self) -> Union[None, int]: @attr.s(cmp=False, slots=True) class PredictedInstance(Instance): """ - A predicted instance is an output of the inference procedure. It is - the main output of the inference procedure. + A predicted instance is an output of the inference procedure. Args: - score: The instance level prediction score. + score: The instance-level prediction score. """ + score: float = attr.ib(default=0.0, converter=float) # The underlying Point array type that this instances point array should be. @@ -650,12 +842,13 @@ def __attrs_post_init__(self): raise ValueError("PredictedInstance should not have from_predicted.") @classmethod - def from_instance(cls, instance: Instance, score): + def from_instance(cls, instance: Instance, score: float): """ - Create a PredictedInstance from and Instance object. The fields are - copied in a shallow manner with the exception of points. For each - point in the instance an PredictedPoint is created with score set - to default value. + Create a :class:`PredictedInstance` from an :class:`Instance`. + + The fields are copied in a shallow manner with the exception of + points. For each point in the instance a :class:`PredictedPoint` + is created with score set to default value. Args: instance: The Instance object to shallow copy data from. @@ -664,19 +857,27 @@ def from_instance(cls, instance: Instance, score): Returns: A PredictedInstance for the given Instance. """ - kw_args = attr.asdict(instance, recurse=False, filter=lambda attr, value: attr.name not in ("_points", "_nodes")) - kw_args['points'] = PredictedPointArray.from_array(instance._points) - kw_args['score'] = score + kw_args = attr.asdict( + instance, + recurse=False, + filter=lambda attr, value: attr.name not in ("_points", "_nodes"), + ) + kw_args["points"] = PredictedPointArray.from_array(instance._points) + kw_args["score"] = score return cls(**kw_args) -def make_instance_cattr(): +def make_instance_cattr() -> cattr.Converter: """ - Create a cattr converter for handling Lists of Instances/PredictedInstances + Create a cattr converter for Lists of Instances/PredictedInstances. + + This is required because cattrs doesn't automatically detect the + class when the attributes of one class are a subset of another. Returns: - A cattr converter with hooks registered for structuring and unstructuring - Instances. + A cattr converter with hooks registered for structuring and + unstructuring :class:`Instance` objects and + :class:`PredictedInstance`s. """ converter = cattr.Converter() @@ -689,15 +890,18 @@ def make_instance_cattr(): converter.register_unstructure_hook(PointArray, lambda x: None) converter.register_unstructure_hook(PredictedPointArray, lambda x: None) + def unstructure_instance(x: Instance): - # Unstructure everything but the points array and frame attribute - d = {field.name: converter.unstructure(x.__getattribute__(field.name)) - for field in attr.fields(x.__class__) - if field.name not in ['_points', 'frame']} + # Unstructure everything but the points array, nodes, and frame attribute + d = { + field.name: converter.unstructure(x.__getattribute__(field.name)) + for field in attr.fields(x.__class__) + if field.name not in ["_points", "_nodes", "frame"] + } # Replace the point array with a dict - d['_points'] = converter.unstructure({k: v for k, v in x.nodes_points}) + d["_points"] = converter.unstructure({k: v for k, v in x.nodes_points}) return d @@ -707,17 +911,18 @@ def unstructure_instance(x: Instance): ## STRUCTURE HOOKS def structure_points(x, type): - if 'score' in x.keys(): + if "score" in x.keys(): return cattr.structure(x, PredictedPoint) else: return cattr.structure(x, Point) converter.register_structure_hook(Union[Point, PredictedPoint], structure_points) + # Function to determine object type for objects being structured. def structure_instances_list(x, type): inst_list = [] for inst_data in x: - if 'score' in inst_data.keys(): + if "score" in inst_data.keys(): inst = converter.structure(inst_data, PredictedInstance) else: inst = converter.structure(inst_data, Instance) @@ -725,11 +930,14 @@ def structure_instances_list(x, type): return inst_list - converter.register_structure_hook(Union[List[Instance], List[PredictedInstance]], - structure_instances_list) + converter.register_structure_hook( + Union[List[Instance], List[PredictedInstance]], structure_instances_list + ) - converter.register_structure_hook(ForwardRef('PredictedInstance'), - lambda x, type: converter.structure(x, PredictedInstance)) + converter.register_structure_hook( + ForwardRef("PredictedInstance"), + lambda x, type: converter.structure(x, PredictedInstance), + ) # We can register structure hooks for point arrays that do nothing # because Instance can have a dict of points passed to it in place of @@ -737,7 +945,7 @@ def structure_instances_list(x, type): def structure_point_array(x, t): if x: point1 = x[list(x.keys())[0]] - if 'score' in point1.keys(): + if "score" in point1.keys(): return converter.structure(x, Dict[Node, PredictedPoint]) else: return converter.structure(x, Dict[Node, Point]) @@ -752,26 +960,46 @@ def structure_point_array(x, t): @attr.s(auto_attribs=True) class LabeledFrame: + """ + Holds labeled data for a single frame of a video. + + Args: + video: The :class:`Video` associated with this frame. + frame_idx: The index of frame in video. + """ + video: Video = attr.ib() frame_idx: int = attr.ib(converter=int) - _instances: Union[List[Instance], List[PredictedInstance]] = attr.ib(default=attr.Factory(list)) + _instances: Union[List[Instance], List[PredictedInstance]] = attr.ib( + default=attr.Factory(list) + ) def __attrs_post_init__(self): + """ + Called by attrs. + + Updates :attribute:`Instance.frame` for each instance associated + with this :class:`LabeledFrame`. + """ # Make sure all instances have a reference to this frame for instance in self.instances: instance.frame = self - def __len__(self): + def __len__(self) -> int: + """Returns number of instances associated with frame.""" return len(self.instances) - def __getitem__(self, index): + def __getitem__(self, index) -> Instance: + """Returns instance (retrieved by index).""" return self.instances.__getitem__(index) - def index(self, value: Instance): + def index(self, value: Instance) -> int: + """Returns index of given :class:`Instance`.""" return self.instances.index(value) def __delitem__(self, index): + """Removes instance (by index) from frame.""" value = self.instances.__getitem__(index) self.instances.__delitem__(index) @@ -779,37 +1007,78 @@ def __delitem__(self, index): # Modify the instance to remove reference to this frame value.frame = None - def insert(self, index, value: Instance): + def insert(self, index: int, value: Instance): + """ + Adds instance to frame. + + Args: + index: The index in list of frame instances where we should + insert the new instance. + value: The instance to associate with frame. + + Returns: + None. + """ self.instances.insert(index, value) # Modify the instance to have a reference back to this frame value.frame = self def __setitem__(self, index, value: Instance): + """ + Sets nth instance in frame to the given instance. + + Args: + index: The index of instance to replace with new instance. + value: The new instance to associate with frame. + + Returns: + None. + """ self.instances.__setitem__(index, value) # Modify the instance to have a reference back to this frame value.frame = self - @property - def instances(self): + def find( + self, track: Optional[Union[Track, int]] = -1, user: bool = False + ) -> List[Instance]: """ - A list of instances to associated with this frame. + Retrieves instances (if any) matching specifications. + + Args: + track: The :class:`Track` to match. Note that None will only + match instances where :attribute:`Instance.track` is + None. If track is -1, then we'll match any track. + user: Whether to only match user (non-predicted) instances. Returns: - A list of instances to associated with this frame. + List of instances. """ + instances = self.instances + if user: + instances = list(filter(lambda inst: type(inst) == Instance, instances)) + if track != -1: # use -1 since we want to accept None as possible value + instances = list(filter(lambda inst: inst.track == track, instances)) + return instances + + @property + def instances(self) -> List[Instance]: + """Returns list of all instances associated with this frame.""" return self._instances @instances.setter def instances(self, instances: List[Instance]): """ - Set the list of instances assigned to this frame. Note: whenever an instance - is associated with a LabeledFrame that Instance objects frame property will - be overwritten to the LabeledFrame. + Sets the list of instances associated with this frame. + + Updates the `frame` attribute on each instance to the + :class:`LabeledFrame` which will contain the instance. + The list of instances replaces instances that were previously + associated with frame. Args: - instances: A list of instances to associated with this frame. + instances: A list of instances associated with this frame. Returns: None @@ -822,52 +1091,103 @@ def instances(self, instances: List[Instance]): self._instances = instances @property - def user_instances(self): - return [inst for inst in self._instances if type(inst) == Instance] + def user_instances(self) -> List[Instance]: + """Returns list of user instances associated with this frame.""" + return [ + inst for inst in self._instances if not isinstance(inst, PredictedInstance) + ] @property - def has_user_instances(self): - return (len(self.user_instances) > 0) + def predicted_instances(self) -> List[PredictedInstance]: + """Returns list of predicted instances associated with frame.""" + return [inst for inst in self._instances if isinstance(inst, PredictedInstance)] @property - def unused_predictions(self): + def has_user_instances(self) -> bool: + """Whether the frame contains any user instances.""" + return len(self.user_instances) > 0 + + @property + def unused_predictions(self) -> List[Instance]: + """ + Returns list of "unused" :class:`PredictedInstance` objects in frame. + + This is all the :class:`PredictedInstance` objects which do not have + a corresponding :class:`Instance` in the same track in frame. + """ unused_predictions = [] any_tracks = [inst.track for inst in self._instances if inst.track is not None] if len(any_tracks): # use tracks to determine which predicted instances have been used - used_tracks = [inst.track for inst in self._instances - if type(inst) == Instance and inst.track is not None - ] - unused_predictions = [inst for inst in self._instances - if inst.track not in used_tracks - and type(inst) == PredictedInstance - ] + used_tracks = [ + inst.track + for inst in self._instances + if type(inst) == Instance and inst.track is not None + ] + unused_predictions = [ + inst + for inst in self._instances + if inst.track not in used_tracks and type(inst) == PredictedInstance + ] else: # use from_predicted to determine which predicted instances have been used # TODO: should we always do this instead of using tracks? - used_instances = [inst.from_predicted for inst in self._instances - if inst.from_predicted is not None] - unused_predictions = [inst for inst in self._instances - if type(inst) == PredictedInstance - and inst not in used_instances] + used_instances = [ + inst.from_predicted + for inst in self._instances + if inst.from_predicted is not None + ] + unused_predictions = [ + inst + for inst in self._instances + if type(inst) == PredictedInstance and inst not in used_instances + ] return unused_predictions @property - def instances_to_show(self): + def instances_to_show(self) -> List[Instance]: """ - Return a list of instances associated with this frame, but excluding any - predicted instances for which there's a corresponding regular instance. + Return a list of instances to show in GUI for this frame. + + This list will not include any predicted instances for which + there's a corresponding regular instance. + + Returns: + List of instances to show in GUI. """ unused_predictions = self.unused_predictions - inst_to_show = [inst for inst in self._instances - if type(inst) == Instance or inst in unused_predictions] - inst_to_show.sort(key=lambda inst: inst.track.spawned_on if inst.track is not None else math.inf) + inst_to_show = [ + inst + for inst in self._instances + if type(inst) == Instance or inst in unused_predictions + ] + inst_to_show.sort( + key=lambda inst: inst.track.spawned_on + if inst.track is not None + else math.inf + ) return inst_to_show @staticmethod - def merge_frames(labeled_frames, video): + def merge_frames( + labeled_frames: List["LabeledFrame"], video: "Video", remove_redundant=True + ) -> List["LabeledFrame"]: + """Merged LabeledFrames for same video and frame index. + + Args: + labeled_frames: List of :class:`LabeledFrame` objects to merge. + video: The :class:`Video` for which to merge. + This is specified so we don't have to check all frames when we + already know which video has new labeled frames. + remove_redundant: Whether to drop instances in the merged frames + where there's a perfect match. + + Returns: + The merged list of :class:`LabeledFrame`s. + """ + redundant_count = 0 frames_found = dict() # move instances into first frame with matching frame_idx for idx, lf in enumerate(labeled_frames): @@ -875,13 +1195,174 @@ def merge_frames(labeled_frames, video): if lf.frame_idx in frames_found.keys(): # move instances dst_idx = frames_found[lf.frame_idx] - labeled_frames[dst_idx].instances.extend(lf.instances) + if remove_redundant: + for new_inst in lf.instances: + redundant = False + for old_inst in labeled_frames[dst_idx].instances: + if new_inst.matches(old_inst): + redundant = True + if not hasattr(new_inst, "score"): + redundant_count += 1 + break + if not redundant: + labeled_frames[dst_idx].instances.append(new_inst) + else: + labeled_frames[dst_idx].instances.extend(lf.instances) lf.instances = [] else: # note first lf with this frame_idx frames_found[lf.frame_idx] = idx # remove labeled frames with no instances - labeled_frames = list(filter(lambda lf: len(lf.instances), - labeled_frames)) + labeled_frames = list(filter(lambda lf: len(lf.instances), labeled_frames)) + if redundant_count: + print(f"skipped {redundant_count} redundant instances") return labeled_frames + @classmethod + def complex_merge_between( + cls, base_labels: "Labels", new_frames: List["LabeledFrame"] + ) -> Tuple[Dict[Video, Dict[int, List[Instance]]], List[Instance], List[Instance]]: + """ + Merge data from new frames into a :class:`Labels` object. + + Everything that can be merged cleanly is merged, any conflicts + are returned. + + Args: + base_labels: The :class:`Labels` into which we are merging. + new_frames: The list of :class:`LabeledFrame` objects from + which we are merging. + Returns: + tuple of three items: + * Dictionary, keys are :class:`Video`, values are + dictionary in which keys are frame index (int) + and value is list of :class:`Instance`s + * list of conflicting :class:`Instance` objects from base + * list of conflicting :class:`Instance` objects from new frames + """ + merged = dict() + extra_base = [] + extra_new = [] + + for new_frame in new_frames: + base_lfs = base_labels.find(new_frame.video, new_frame.frame_idx) + merged_instances = None + + # If the base doesn't have a frame corresponding this new + # frame, then it can be merged cleanly. + if not base_lfs: + base_labels.labeled_frames.append(new_frame) + merged_instances = new_frame.instances + else: + # There's a corresponding frame in the base labels, + # so try merging the data. + merged_instances, extra_base_frame, extra_new_frame = cls.complex_frame_merge( + base_lfs[0], new_frame + ) + if extra_base_frame: + extra_base.append(extra_base_frame) + if extra_new_frame: + extra_new.append(extra_new_frame) + + if merged_instances: + if new_frame.video not in merged: + merged[new_frame.video] = dict() + merged[new_frame.video][new_frame.frame_idx] = merged_instances + return merged, extra_base, extra_new + + @classmethod + def complex_frame_merge( + cls, base_frame: "LabeledFrame", new_frame: "LabeledFrame" + ) -> Tuple[List[Instance], List[Instance], List[Instance]]: + """ + Merge two frames, return conflicts if any. + + A conflict occurs when + * each frame has Instances which don't perfectly match those + in the other frame, or + * each frame has PredictedInstances which don't perfectly match + those in the other frame. + + Args: + base_frame: The `LabeledFrame` into which we want to merge. + new_frame: The `LabeledFrame` from which we want to merge. + + Returns: + tuple of three items: + * list of instances that were merged + * list of conflicting instances from base + * list of conflicting instances from new + """ + merged_instances = [] + redundant_instances = [] + extra_base_instances = copy(base_frame.instances) + extra_new_instances = [] + + for new_inst in new_frame: + redundant = False + for base_inst in base_frame.instances: + if new_inst.matches(base_inst): + base_inst.frame = None + extra_base_instances.remove(base_inst) + redundant_instances.append(base_inst) + redundant = True + continue + if not redundant: + new_inst.frame = None + extra_new_instances.append(new_inst) + + conflict = False + if extra_base_instances and extra_new_instances: + base_predictions = list( + filter(lambda inst: hasattr(inst, "score"), extra_base_instances) + ) + new_predictions = list( + filter(lambda inst: hasattr(inst, "score"), extra_new_instances) + ) + + base_has_nonpred = len(extra_base_instances) - len(base_predictions) + new_has_nonpred = len(extra_new_instances) - len(new_predictions) + + # If they both have some predictions or they both have some + # non-predictions, then there is a conflict. + # (Otherwise it's not a conflict since we can cleanly merge + # all the predicted instances with all the non-predicted.) + if base_predictions and new_predictions: + conflict = True + elif base_has_nonpred and new_has_nonpred: + conflict = True + + if conflict: + # Conflict, so update base to just include non-conflicting + # instances (perfect matches) + base_frame.instances.clear() + base_frame.instances.extend(redundant_instances) + else: + # No conflict, so include all instances in base + base_frame.instances.extend(extra_new_instances) + merged_instances = copy(extra_new_instances) + extra_base_instances = [] + extra_new_instances = [] + + # Construct frames to hold any conflicting instances + extra_base = ( + cls( + video=base_frame.video, + frame_idx=base_frame.frame_idx, + instances=extra_base_instances, + ) + if extra_base_instances + else None + ) + + extra_new = ( + cls( + video=new_frame.video, + frame_idx=new_frame.frame_idx, + instances=extra_new_instances, + ) + if extra_new_instances + else None + ) + + return merged_instances, extra_base, extra_new diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index 42686a268..4b68e61c3 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -1,23 +1,18 @@ -"""A LEAP Dataset represents annotated (labeled) video data. - -A LEAP Dataset stores almost all data required for training of a model. -This includes, raw video frame data, labelled instances of skeleton _points, -confidence maps, part affinity fields, and skeleton data. A LEAP :class:`.Dataset` -is a high level API to these data structures that abstracts away their underlying -storage format. +""" +A SLEAP dataset collects labeled video frames. +This contains labeled frame data (user annotations and/or predictions), +together with all the other data that is saved for a SLEAP project +(videos, skeletons, negative training sample anchors, etc.). """ import os import re import zipfile import atexit -import glob import attr import cattr -import json -import rapidjson import shutil import tempfile import numpy as np @@ -25,7 +20,7 @@ import h5py as h5 from collections import MutableSequence -from typing import List, Union, Dict, Optional, Tuple +from typing import Callable, List, Union, Dict, Optional try: from typing import ForwardRef @@ -35,62 +30,57 @@ import pandas as pd from sleap.skeleton import Skeleton, Node -from sleap.instance import Instance, Point, LabeledFrame, \ - Track, PredictedPoint, PredictedInstance, \ - make_instance_cattr, PointArray, PredictedPointArray -from sleap.rangelist import RangeList +from sleap.instance import ( + Instance, + Point, + LabeledFrame, + Track, + PredictedPoint, + PredictedInstance, + make_instance_cattr, + PointArray, + PredictedPointArray, +) + +from sleap.io.legacy import load_labels_json_old from sleap.io.video import Video -from sleap.util import uniquify - - -def json_loads(json_str: str): - try: - return rapidjson.loads(json_str) - except: - return json.loads(json_str) - -def json_dumps(d: Dict, filename: str = None): - """ - A simple wrapper around the JSON encoder we are using. - - Args: - d: The dict to write. - f: The filename to write to. - - Returns: - None - """ - import codecs - encoder = rapidjson +from sleap.rangelist import RangeList +from sleap.util import uniquify, weak_filename_match, json_dumps, json_loads - if filename: - with open(filename, 'w') as f: - encoder.dump(d, f, ensure_ascii=False) - else: - return encoder.dumps(d) """ The version number to put in the Labels JSON format. """ LABELS_JSON_FILE_VERSION = "2.0.0" + @attr.s(auto_attribs=True) class Labels(MutableSequence): """ - The LEAP :class:`.Labels` class represents an API for accessing labeled video - frames and other associated metadata. This class is front-end for all - interactions with loading, writing, and modifying these labels. The actual - storage backend for the data is mostly abstracted away from the main - interface. - - Args: - labeled_frames: A list of `LabeledFrame`s - videos: A list of videos that these labels may or may not reference. - That is, every LabeledFrame's video will be in videos but a Video - object from videos might not have any LabeledFrame. - skeletons: A list of skeletons that these labels may or may not reference. - tracks: A list of tracks that instances can belong to. - suggestions: A dict with a list for each video of suggested frames to label. + The :class:`Labels` class collects the data for a SLEAP project. + + This class is front-end for all interactions with loading, writing, + and modifying these labels. The actual storage backend for the data + is mostly abstracted away from the main interface. + + Attributes: + labeled_frames: A list of :class:`LabeledFrame` objects + videos: A list of :class:`Video` objects that these labels may or may + not reference. The video for every `LabeledFrame` will be + stored in `videos` attribute, but some videos in + this list may not have any associated labeled frames. + skeletons: A list of :class:`Skeleton` objects (again, that may or may + not be referenced by an :class:`Instance` in labeled frame). + tracks: A list of :class:`Track` that instances can belong to. + suggestions: Dictionary that stores "suggested" frames for + videos in project. These can be suggested frames for user + to label or suggested frames for user to review. + Dictionary key is :class:`Video`, value is list of frame + indices. + negative_anchors: Dictionary that stores center-points around + which to crop as negative samples when training. + Dictionary key is :class:`Video`, value is list of + (frame index, x, y) tuples. """ labeled_frames: List[LabeledFrame] = attr.ib(default=attr.Factory(list)) @@ -102,73 +92,134 @@ class Labels(MutableSequence): negative_anchors: Dict[Video, list] = attr.ib(default=attr.Factory(dict)) def __attrs_post_init__(self): + """ + Called by attrs after the class is instantiated. + + This updates the top level contains (videos, skeletons, etc) + from data in the labeled frames, as well as various caches. + """ # Add any videos/skeletons/nodes/tracks that are in labeled # frames but not in the lists on our object self._update_from_labels() # Update caches used to find frames by frame index - self._update_lookup_cache() + self._build_lookup_caches() # Create a variable to store a temporary storage directory # used when we unzip self.__temp_dir = None - def _update_from_labels(self, merge=False): - """Update top level attributes with data from labeled frames. + def _update_from_labels(self, merge: bool = False): + """Updates top level attributes with data from labeled frames. Args: - merge: if True, then update even if there's already data + merge: If True, then update even if there's already data. + + Returns: + None. """ # Add any videos that are present in the labels but # missing from the video list if merge or len(self.videos) == 0: - self.videos = list(set(self.videos).union({label.video for label in self.labels})) + # find videos in labeled frames that aren't yet in top level videos + new_videos = {label.video for label in self.labels} - set(self.videos) + # just add the new videos so we don't re-order current list + if len(new_videos): + self.videos.extend(list(new_videos)) # Ditto for skeletons if merge or len(self.skeletons) == 0: - self.skeletons = list(set(self.skeletons).union( - {instance.skeleton - for label in self.labels - for instance in label.instances})) + self.skeletons = list( + set(self.skeletons).union( + { + instance.skeleton + for label in self.labels + for instance in label.instances + } + ) + ) # Ditto for nodes if merge or len(self.nodes) == 0: - self.nodes = list(set(self.nodes).union({node for skeleton in self.skeletons for node in skeleton.nodes})) + self.nodes = list( + set(self.nodes).union( + {node for skeleton in self.skeletons for node in skeleton.nodes} + ) + ) # Ditto for tracks, a pattern is emerging here if merge or len(self.tracks) == 0: - tracks = set(self.tracks) - - # Add tracks from any Instances or PredictedInstances - tracks = tracks.union({instance.track - for frame in self.labels - for instance in frame.instances - if instance.track}) + # Get tracks from any Instances or PredictedInstances + other_tracks = { + instance.track + for frame in self.labels + for instance in frame.instances + if instance.track + } # Add tracks from any PredictedInstance referenced by instance # This fixes things when there's a referenced PredictionInstance # which is no longer in the frame. - tracks = tracks.union({instance.from_predicted.track - for frame in self.labels - for instance in frame.instances - if instance.from_predicted - and instance.from_predicted.track}) + other_tracks = other_tracks.union( + { + instance.from_predicted.track + for frame in self.labels + for instance in frame.instances + if instance.from_predicted and instance.from_predicted.track + } + ) + + # Get list of other tracks not already in track list + new_tracks = list(other_tracks - set(self.tracks)) + + # Sort the new tracks by spawned on and then name + new_tracks.sort(key=lambda t: (t.spawned_on, t.name)) + + self.tracks.extend(new_tracks) + + def _update_containers(self, new_label: LabeledFrame): + """ Ensure that top-level containers are kept updated with new + instances of objects that come along with new labels. """ + + if new_label.video not in self.videos: + self.videos.append(new_label.video) + + for skeleton in {instance.skeleton for instance in new_label}: + if skeleton not in self.skeletons: + self.skeletons.append(skeleton) + for node in skeleton.nodes: + if node not in self.nodes: + self.nodes.append(node) + + # Add any new Tracks as well + for instance in new_label.instances: + if instance.track and instance.track not in self.tracks: + self.tracks.append(instance.track) - self.tracks = list(tracks) + # Sort the tracks again + self.tracks.sort(key=lambda t: (t.spawned_on, t.name)) - # Sort the tracks by spawned on and then name - self.tracks.sort(key=lambda t:(t.spawned_on, t.name)) + # Update cache datastructures + if new_label.video not in self._lf_by_video: + self._lf_by_video[new_label.video] = [] + if new_label.video not in self._frame_idx_map: + self._frame_idx_map[new_label.video] = dict() + self._lf_by_video[new_label.video].append(new_label) + self._frame_idx_map[new_label.video][new_label.frame_idx] = new_label - def _update_lookup_cache(self): + def _build_lookup_caches(self): + """Builds (or rebuilds) various caches.""" # Data structures for caching self._lf_by_video = dict() self._frame_idx_map = dict() self._track_occupancy = dict() for video in self.videos: self._lf_by_video[video] = [lf for lf in self.labels if lf.video == video] - self._frame_idx_map[video] = {lf.frame_idx: lf for lf in self._lf_by_video[video]} + self._frame_idx_map[video] = { + lf.frame_idx: lf for lf in self._lf_by_video[video] + } self._track_occupancy[video] = self._make_track_occupany(video) # Below are convenience methods for working with Labels as list. @@ -179,20 +230,30 @@ def _update_lookup_cache(self): @property def labels(self): - """ Alias for labeled_frames """ + """Alias for labeled_frames.""" return self.labeled_frames - @property - def user_labeled_frames(self): - return [lf for lf in self.labeled_frames if lf.has_user_instances] - - def __len__(self): + def __len__(self) -> int: + """Returns number of labeled frames.""" return len(self.labeled_frames) - def index(self, value): + def index(self, value) -> int: + """Returns index of labeled frame in list of labeled frames.""" return self.labeled_frames.index(value) - def __contains__(self, item): + def __contains__(self, item) -> bool: + """ + Checks if object contains the given item. + + Args: + item: The item to look for within `Labels`. + This can be :class:`LabeledFrame`, + :class:`Video`, :class:`Skeleton`, + :class:`Node`, or (:class:`Video`, frame idx) tuple. + + Returns: + True if item is found. + """ if isinstance(item, LabeledFrame): return item in self.labeled_frames elif isinstance(item, Video): @@ -201,10 +262,26 @@ def __contains__(self, item): return item in self.skeletons elif isinstance(item, Node): return item in self.nodes - elif isinstance(item, tuple) and len(item) == 2 and isinstance(item[0], Video) and isinstance(item[1], int): + elif ( + isinstance(item, tuple) + and len(item) == 2 + and isinstance(item[0], Video) + and isinstance(item[1], int) + ): return self.find_first(*item) is not None - def __getitem__(self, key): + def __getitem__(self, key) -> List[LabeledFrame]: + """Returns labeled frames matching key. + + Args: + key: `Video` or (`Video`, frame index) to match against. + + Raises: + KeyError: If labeled frame for `Video` or frame index + cannot be found. + + Returns: A list with the matching labeled frame(s). + """ if isinstance(key, int): return self.labels.__getitem__(key) @@ -213,7 +290,12 @@ def __getitem__(self, key): raise KeyError("Video not found in labels.") return self.find(video=key) - elif isinstance(key, tuple) and len(key) == 2 and isinstance(key[0], Video) and isinstance(key[1], int): + elif ( + isinstance(key, tuple) + and len(key) == 2 + and isinstance(key[0], Video) + and isinstance(key[1], int) + ): if key[0] not in self.videos: raise KeyError("Video not found in labels.") @@ -227,45 +309,109 @@ def __getitem__(self, key): else: raise KeyError("Invalid label indexing arguments.") - def find(self, video: Video, frame_idx: Union[int, range] = None, return_new: bool=False) -> List[LabeledFrame]: + def __setitem__(self, index, value: LabeledFrame): + """Sets labeled frame at given index.""" + # TODO: Maybe we should remove this method altogether? + self.labeled_frames.__setitem__(index, value) + self._update_containers(value) + + def insert(self, index, value: LabeledFrame): + """Inserts labeled frame at given index.""" + if value in self or (value.video, value.frame_idx) in self: + return + + self.labeled_frames.insert(index, value) + self._update_containers(value) + + def append(self, value: LabeledFrame): + """Adds labeled frame to list of labeled frames.""" + self.insert(len(self) + 1, value) + + def __delitem__(self, key): + """Removes labeled frame with given index.""" + self.labeled_frames.remove(self.labeled_frames[key]) + + def remove(self, value: LabeledFrame): + """Removes given labeled frame.""" + self.labeled_frames.remove(value) + self._lf_by_video[value.video].remove(value) + del self._frame_idx_map[value.video][value.frame_idx] + + def find( + self, + video: Video, + frame_idx: Optional[Union[int, range]] = None, + return_new: bool = False, + ) -> List[LabeledFrame]: """ Search for labeled frames given video and/or frame index. Args: - video: a `Video` instance that is associated with the labeled frames - frame_idx: an integer specifying the frame index within the video - return_new: return singleton of new `LabeledFrame` if none found? + video: A :class:`Video` that is associated with the project. + frame_idx: The frame index (or indices) which we want to + find in the video. If a range is specified, we'll return + all frames with indices in that range. If not specific, + then we'll return all labeled frames for video. + return_new: Whether to return singleton of new and empty + :class:`LabeledFrame` if none is found in project. Returns: - List of `LabeledFrame`s that match the criteria. Empty if no matches found. - + List of `LabeledFrame` objects that match the criteria. + Empty if no matches found, unless return_new is True, + in which case it contains a new `LabeledFrame` with + `video` and `frame_index` set. """ - null_result = [LabeledFrame(video=video, frame_idx=frame_idx)] if return_new else [] + null_result = ( + [LabeledFrame(video=video, frame_idx=frame_idx)] if return_new else [] + ) if frame_idx is not None: - if video not in self._frame_idx_map: return null_result + if video not in self._frame_idx_map: + return null_result if type(frame_idx) == range: - return [self._frame_idx_map[video][idx] for idx in frame_idx if idx in self._frame_idx_map[video]] + return [ + self._frame_idx_map[video][idx] + for idx in frame_idx + if idx in self._frame_idx_map[video] + ] - if frame_idx not in self._frame_idx_map[video]: return null_result + if frame_idx not in self._frame_idx_map[video]: + return null_result return [self._frame_idx_map[video][frame_idx]] else: - if video not in self._lf_by_video: return null_result + if video not in self._lf_by_video: + return null_result return self._lf_by_video[video] def frames(self, video: Video, from_frame_idx: int = -1, reverse=False): """ - Iterator over all frames in a video, starting with first frame - after specified frame_idx (or first frame in video if none specified). + Iterator over all labeled frames in a video. + + Args: + video: A :class:`Video` that is associated with the project. + from_frame_idx: The frame index from which we want to start. + Defaults to the first frame of video. + reverse: Whether to iterate over frames in reverse order. + + Yields: + :class:`LabeledFrame` """ - if video not in self._frame_idx_map: return None + if video not in self._frame_idx_map: + return None # Get sorted list of frame indexes for this video - frame_idxs = sorted(self._frame_idx_map[video].keys(), reverse=reverse) + frame_idxs = sorted(self._frame_idx_map[video].keys()) - # Find the next frame index after the specified frame - next_frame_idx = min(filter(lambda x: x > from_frame_idx, frame_idxs), default=frame_idxs[0]) + # Find the next frame index after (before) the specified frame + if not reverse: + next_frame_idx = min( + filter(lambda x: x > from_frame_idx, frame_idxs), default=frame_idxs[0] + ) + else: + next_frame_idx = max( + filter(lambda x: x < from_frame_idx, frame_idxs), default=frame_idxs[-1] + ) cut_list_idx = frame_idxs.index(next_frame_idx) # Shift list of frame indices to start with specified frame @@ -275,66 +421,175 @@ def frames(self, video: Video, from_frame_idx: int = -1, reverse=False): for idx in frame_idxs: yield self._frame_idx_map[video][idx] - def find_first(self, video: Video, frame_idx: int = None) -> LabeledFrame: - """ Find the first occurrence of a labeled frame for the given video and/or frame index. + def find_first( + self, video: Video, frame_idx: Optional[int] = None + ) -> Optional[LabeledFrame]: + """ + Finds the first occurrence of a matching labeled frame. + + Matches on frames for the given video and/or frame index. Args: - video: a `Video` instance that is associated with the labeled frames - frame_idx: an integer specifying the frame index within the video + video: a `Video` instance that is associated with the + labeled frames + frame_idx: an integer specifying the frame index within + the video Returns: - First `LabeledFrame` that match the criteria or None if none were found. + First `LabeledFrame` that match the criteria + or None if none were found. """ if video in self.videos: for label in self.labels: - if label.video == video and (frame_idx is None or (label.frame_idx == frame_idx)): + if label.video == video and ( + frame_idx is None or (label.frame_idx == frame_idx) + ): return label - def find_last(self, video: Video, frame_idx: int = None) -> LabeledFrame: - """ Find the last occurrence of a labeled frame for the given video and/or frame index. + def find_last( + self, video: Video, frame_idx: Optional[int] = None + ) -> Optional[LabeledFrame]: + """ + Finds the last occurrence of a matching labeled frame. + + Matches on frames for the given video and/or frame index. Args: - video: A `Video` instance that is associated with the labeled frames - frame_idx: An integer specifying the frame index within the video + video: a `Video` instance that is associated with the + labeled frames + frame_idx: an integer specifying the frame index within + the video Returns: - LabeledFrame: Last label that matches the criteria or None if no results. + Last `LabeledFrame` that match the criteria + or None if none were found. """ if video in self.videos: for label in reversed(self.labels): - if label.video == video and (frame_idx is None or (label.frame_idx == frame_idx)): + if label.video == video and ( + frame_idx is None or (label.frame_idx == frame_idx) + ): return label + @property + def user_labeled_frames(self): + """ + Returns all labeled frames with user (non-predicted) instances. + """ + return [lf for lf in self.labeled_frames if lf.has_user_instances] + + def get_video_user_labeled_frames(self, video: Video) -> List[LabeledFrame]: + """ + Returns labeled frames for given video with user instances. + """ + return [ + lf + for lf in self.labeled_frames + if lf.has_user_instances and lf.video == video + ] + + # Methods for instances + def instance_count(self, video: Video, frame_idx: int) -> int: + """Returns number of instances matching video/frame index.""" count = 0 labeled_frame = self.find_first(video, frame_idx) if labeled_frame is not None: - count = len([inst for inst in labeled_frame.instances if type(inst)==Instance]) + count = len( + [inst for inst in labeled_frame.instances if type(inst) == Instance] + ) return count - def get_track_occupany(self, video: Video): + @property + def all_instances(self): + """Returns list of all instances.""" + return list(self.instances()) + + @property + def user_instances(self): + """Returns list of all user (non-predicted) instances.""" + return [inst for inst in self.all_instances if type(inst) == Instance] + + def instances(self, video: Video = None, skeleton: Skeleton = None): + """ + Iterate over instances in the labels, optionally with filters. + + Args: + video: Only iterate through instances in this video + skeleton: Only iterate through instances with this skeleton + + Yields: + Instance: The next labeled instance + """ + for label in self.labels: + if video is None or label.video == video: + for instance in label.instances: + if skeleton is None or instance.skeleton == skeleton: + yield instance + + # Methods for tracks + + def get_track_occupany(self, video: Video) -> List: + """Returns track occupancy list for given video""" try: return self._track_occupancy[video] except: return [] def add_track(self, video: Video, track: Track): + """Adds track to labels, updating occupancy.""" self.tracks.append(track) self._track_occupancy[video][track] = RangeList() - def track_set_instance(self, frame: LabeledFrame, instance: Instance, new_track: Track): - self.track_swap(frame.video, new_track, instance.track, (frame.frame_idx, frame.frame_idx+1)) + def track_set_instance( + self, frame: LabeledFrame, instance: Instance, new_track: Track + ): + """Sets track on given instance, updating occupancy.""" + self.track_swap( + frame.video, + new_track, + instance.track, + (frame.frame_idx, frame.frame_idx + 1), + ) if instance.track is None: self._track_remove_instance(frame, instance) instance.track = new_track - def track_swap(self, video: Video, new_track: Track, old_track: Track, frame_range: tuple): + def track_swap( + self, + video: Video, + new_track: Track, + old_track: Optional[Track], + frame_range: tuple, + ): + """ + Swaps track assignment for instances in two tracks. + + If you need to change the track to or from None, you'll need + to use :meth:`track_set_instance` for each specific + instance you want to modify. + + Args: + video: The :class:`Video` for which we want to swap tracks. + new_track: A :class:`Track` for which we want to swap + instances with another track. + old_track: The other :class:`Track` for swapping. + frame_range: Tuple of (start, end) frame indexes. + If you want to swap tracks on a single frame, use + (frame index, frame index + 1). + Returns: + None. + """ # Get ranges in track occupancy cache - _, within_old, _ = self._track_occupancy[video][old_track].cut_range(frame_range) - _, within_new, _ = self._track_occupancy[video][new_track].cut_range(frame_range) + _, within_old, _ = self._track_occupancy[video][old_track].cut_range( + frame_range + ) + _, within_new, _ = self._track_occupancy[video][new_track].cut_range( + frame_range + ) if old_track is not None: # Instances that didn't already have track can't be handled here. @@ -352,33 +607,42 @@ def track_swap(self, video: Video, new_track: Track, old_track: Track, frame_ran new_track_instances = self.find_track_occupancy(video, new_track, frame_range) # swap new to old tracks on all instances - for frame, instance in old_track_instances: + for instance in old_track_instances: instance.track = new_track # old_track can be `Track` or int # If int, it's index in instance list which we'll use as a pseudo-track, # but we won't set instances currently on new_track to old_track. if type(old_track) == Track: - for frame, instance in new_track_instances: + for instance in new_track_instances: instance.track = old_track def _track_remove_instance(self, frame: LabeledFrame, instance: Instance): - if instance.track not in self._track_occupancy[frame.video]: return + """Manipulates track occupancy cache.""" + if instance.track not in self._track_occupancy[frame.video]: + return # If this is only instance in track in frame, then remove frame from track. - if len(list(filter(lambda inst: inst.track == instance.track, frame.instances))) == 1: - self._track_occupancy[frame.video][instance.track].remove((frame.frame_idx, frame.frame_idx+1)) + if len(frame.find(track=instance.track)) == 1: + self._track_occupancy[frame.video][instance.track].remove( + (frame.frame_idx, frame.frame_idx + 1) + ) def remove_instance(self, frame: LabeledFrame, instance: Instance): + """Removes instance from frame, updating track occupancy.""" self._track_remove_instance(frame, instance) frame.instances.remove(instance) def add_instance(self, frame: LabeledFrame, instance: Instance): + """Adds instance to frame, updating track occupancy.""" if frame.video not in self._track_occupancy: self._track_occupancy[frame.video] = dict() # Ensure that there isn't already an Instance with this track - tracks_in_frame = [inst.track for inst in frame - if type(inst) == Instance and inst.track is not None] + tracks_in_frame = [ + inst.track + for inst in frame + if type(inst) == Instance and inst.track is not None + ] if instance.track in tracks_in_frame: instance.track = None @@ -386,10 +650,13 @@ def add_instance(self, frame: LabeledFrame, instance: Instance): if instance.track not in self._track_occupancy[frame.video]: self._track_occupancy[frame.video][instance.track] = RangeList() - self._track_occupancy[frame.video][instance.track].insert((frame.frame_idx, frame.frame_idx+1)) + self._track_occupancy[frame.video][instance.track].insert( + (frame.frame_idx, frame.frame_idx + 1) + ) frame.instances.append(instance) - def _make_track_occupany(self, video): + def _make_track_occupany(self, video: Video) -> Dict[Video, RangeList]: + """Build cached track occupancy data.""" frame_idx_map = self._frame_idx_map[video] tracks = dict() @@ -402,8 +669,10 @@ def _make_track_occupany(self, video): tracks[instance.track].add(frame_idx) return tracks - def find_track_occupancy(self, video: Video, track: Union[Track, int], frame_range=None) -> List[Tuple[LabeledFrame, Instance]]: - """Get instances for a given track. + def find_track_occupancy( + self, video: Video, track: Union[Track, int], frame_range=None + ) -> List[Instance]: + """Get instances for a given video, track, and range of frames. Args: video: the `Video` @@ -412,7 +681,7 @@ def find_track_occupancy(self, video: Video, track: Union[Track, int], frame_ran If specified, only return instances on frames in range. If None, return all instances for given track. Returns: - list of `Instance` objects + List of :class:`Instance` objects. """ frame_range = range(*frame_range) if type(frame_range) == tuple else frame_range @@ -421,98 +690,99 @@ def does_track_match(inst, tr, labeled_frame): match = False if type(tr) == Track and inst.track is tr: match = True - elif (type(tr) == int and labeled_frame.instances.index(inst) == tr - and inst.track is None): + elif ( + type(tr) == int + and labeled_frame.instances.index(inst) == tr + and inst.track is None + ): match = True return match - track_frame_inst = [(lf, instance) - for lf in self.find(video) - for instance in lf.instances - if does_track_match(instance, track, lf) - and (frame_range is None or lf.frame_idx in frame_range)] + track_frame_inst = [ + instance + for lf in self.find(video) + for instance in lf.instances + if does_track_match(instance, track, lf) + and (frame_range is None or lf.frame_idx in frame_range) + ] return track_frame_inst + # Methods for suggestions - def find_track_instances(self, *args, **kwargs) -> List[Instance]: - return [inst for lf, inst in self.find_track_occupancy(*args, **kwargs)] - - @property - def all_instances(self): - return list(self.instances()) - - @property - def user_instances(self): - return [inst for inst in self.all_instances if type(inst) == Instance] - - def instances(self, video: Video = None, skeleton: Skeleton = None): - """ Iterate through all instances in the labels, optionally with filters. - - Args: - video: Only iterate through instances in this video - skeleton: Only iterate through instances with this skeleton - - Yields: - Instance: The next labeled instance + def get_video_suggestions(self, video: Video) -> list: """ - for label in self.labels: - if video is None or label.video == video: - for instance in label.instances: - if skeleton is None or instance.skeleton == skeleton: - yield instance - - def _update_containers(self, new_label: LabeledFrame): - """ Ensure that top-level containers are kept updated with new - instances of objects that come along with new labels. """ - - if new_label.video not in self.videos: - self.videos.append(new_label.video) - - for skeleton in {instance.skeleton for instance in new_label}: - if skeleton not in self.skeletons: - self.skeletons.append(skeleton) - for node in skeleton.nodes: - if node not in self.nodes: - self.nodes.append(node) - - # Add any new Tracks as well - for instance in new_label.instances: - if instance.track and instance.track not in self.tracks: - self.tracks.append(instance.track) + Returns the list of suggested frames for the specified video + or suggestions for all videos (if no video specified). + """ + return self.suggestions.get(video, list()) - # Sort the tracks again - self.tracks.sort(key=lambda t: (t.spawned_on, t.name)) + def get_suggestions(self) -> list: + """Return all suggestions as a list of (video, frame) tuples.""" + suggestion_list = [ + (video, frame_idx) + for video in self.videos + for frame_idx in self.get_video_suggestions(video) + ] + return suggestion_list - # Update cache datastructures - if new_label.video not in self._lf_by_video: - self._lf_by_video[new_label.video] = [] - if new_label.video not in self._frame_idx_map: - self._frame_idx_map[new_label.video] = dict() - self._lf_by_video[new_label.video].append(new_label) - self._frame_idx_map[new_label.video][new_label.frame_idx] = new_label + def get_next_suggestion(self, video, frame_idx, seek_direction=1) -> list: + """Returns a (video, frame_idx) tuple seeking from given frame.""" + # make sure we have valid seek_direction + if seek_direction not in (-1, 1): + return (None, None) + # make sure the video belongs to this Labels object + if video not in self.videos: + return (None, None) - def __setitem__(self, index, value: LabeledFrame): - # TODO: Maybe we should remove this method altogether? - self.labeled_frames.__setitem__(index, value) - self._update_containers(value) + all_suggestions = self.get_suggestions() - def insert(self, index, value: LabeledFrame): - if value in self or (value.video, value.frame_idx) in self: - return + # If we're currently on a suggestion, then follow order of list + if (video, frame_idx) in all_suggestions: + suggestion_idx = all_suggestions.index((video, frame_idx)) + new_idx = (suggestion_idx + seek_direction) % len(all_suggestions) + video, frame_suggestion = all_suggestions[new_idx] - self.labeled_frames.insert(index, value) - self._update_containers(value) + # Otherwise, find the prev/next suggestion sorted by frame order + else: + # look for next (or previous) suggestion in current video + if seek_direction == 1: + frame_suggestion = min( + (i for i in self.get_video_suggestions(video) if i > frame_idx), + default=None, + ) + else: + frame_suggestion = max( + (i for i in self.get_video_suggestions(video) if i < frame_idx), + default=None, + ) + if frame_suggestion is not None: + return (video, frame_suggestion) + # if we didn't find suggestion in current video, + # then we want earliest frame in next video with suggestions + next_video_idx = (self.videos.index(video) + seek_direction) % len( + self.videos + ) + video = self.videos[next_video_idx] + if seek_direction == 1: + frame_suggestion = min( + (i for i in self.get_video_suggestions(video)), default=None + ) + else: + frame_suggestion = max( + (i for i in self.get_video_suggestions(video)), default=None + ) + return (video, frame_suggestion) - def append(self, value: LabeledFrame): - self.insert(len(self) + 1, value) + def set_suggestions(self, suggestions: Dict[Video, list]): + """Sets the suggested frames.""" + self.suggestions = suggestions - def __delitem__(self, key): - self.labeled_frames.remove(self.labeled_frames[key]) + def delete_suggestions(self, video): + """Deletes suggestions for specified video.""" + if video in self.suggestions: + del self.suggestions[video] - def remove(self, value: LabeledFrame): - self.labeled_frames.remove(value) - self._lf_by_video[new_label.video].remove(value) - del self._frame_idx_map[new_label.video][value.frame_idx] + # Methods for videos def add_video(self, video: Video): """ Add a video to the labels if it is not already in it. @@ -543,8 +813,7 @@ def remove_video(self, video: Video): self.labeled_frames.remove(label) # Delete data that's indexed by video - if video in self.suggestions: - del self.suggestions[video] + self.delete_suggestions(video) if video in self.negative_anchors: del self.negative_anchors[video] @@ -557,7 +826,9 @@ def remove_video(self, video: Video): if video in self._frame_idx_map: del self._frame_idx_map[video] - def add_negative_anchor(self, video:Video, frame_idx: int, where: tuple): + # Methods for negative anchors + + def add_negative_anchor(self, video: Video, frame_idx: int, where: tuple): """Adds a location for a negative training sample. Args: @@ -569,72 +840,58 @@ def add_negative_anchor(self, video:Video, frame_idx: int, where: tuple): self.negative_anchors[video] = [] self.negative_anchors[video].append((frame_idx, *where)) - def get_video_suggestions(self, video:Video) -> list: - """ - Returns the list of suggested frames for the specified video - or suggestions for all videos (if no video specified). - """ - return self.suggestions.get(video, list()) - - def get_suggestions(self) -> list: - """Return all suggestions as a list of (video, frame) tuples.""" - suggestion_list = [(video, frame_idx) - for video in self.videos - for frame_idx in self.get_video_suggestions(video) - ] - return suggestion_list - - def get_next_suggestion(self, video, frame_idx, seek_direction=1) -> list: - """Returns a (video, frame_idx) tuple.""" - # make sure we have valid seek_direction - if seek_direction not in (-1, 1): return (None, None) - # make sure the video belongs to this Labels object - if video not in self.videos: return (None, None) - - all_suggestions = self.get_suggestions() + def remove_negative_anchors(self, video: Video, frame_idx: int): + """Removes negative training samples for given video and frame. - # If we're currently on a suggestion, then follow order of list - if (video, frame_idx) in all_suggestions: - suggestion_idx = all_suggestions.index((video, frame_idx)) - new_idx = (suggestion_idx+seek_direction)%len(all_suggestions) - video, frame_suggestion = all_suggestions[new_idx] + Args: + video: the `Video` for which we're removing negative samples + frame_idx: frame index + Returns: + None + """ + if video not in self.negative_anchors: + return - # Otherwise, find the prev/next suggestion sorted by frame order - else: - # look for next (or previous) suggestion in current video - if seek_direction == 1: - frame_suggestion = min((i for i in self.get_video_suggestions(video) if i > frame_idx), default=None) - else: - frame_suggestion = max((i for i in self.get_video_suggestions(video) if i < frame_idx), default=None) - if frame_suggestion is not None: return (video, frame_suggestion) - # if we didn't find suggestion in current video, - # then we want earliest frame in next video with suggestions - next_video_idx = (self.videos.index(video) + seek_direction) % len(self.videos) - video = self.videos[next_video_idx] - if seek_direction == 1: - frame_suggestion = min((i for i in self.get_video_suggestions(video)), default=None) - else: - frame_suggestion = max((i for i in self.get_video_suggestions(video)), default=None) - return (video, frame_suggestion) + anchors = [ + (idx, x, y) + for idx, x, y in self.negative_anchors[video] + if idx != frame_idx + ] + self.negative_anchors[video] = anchors - def set_suggestions(self, suggestions:Dict[Video, list]): - """Sets the suggested frames.""" - self.suggestions = suggestions + # Methods for saving/loading - def extend_from(self, new_frames): - """Merge data from another Labels object or list of LabeledFrames into self. + def extend_from( + self, new_frames: Union["Labels", List[LabeledFrame]], unify: bool = False + ): + """ + Merge data from another `Labels` object or `LabeledFrame` list. Arg: new_frames: the object from which to copy data + unify: whether to replace objects in new frames with + corresponding objects from current `Labels` data + Returns: bool, True if we added frames, False otherwise """ # allow either Labels or list of LabeledFrames - if isinstance(new_frames, Labels): new_frames = new_frames.labeled_frames + if isinstance(new_frames, Labels): + new_frames = new_frames.labeled_frames # return if this isn't non-empty list of labeled frames - if not isinstance(new_frames, list) or len(new_frames) == 0: return False - if not isinstance(new_frames[0], LabeledFrame): return False + if not isinstance(new_frames, list) or len(new_frames) == 0: + return False + if not isinstance(new_frames[0], LabeledFrame): + return False + + # If unify, we want to replace objects in the frames with + # corresponding objects from the current labels. + # We do this by deserializing/serializing with match_to. + if unify: + new_json = Labels(labeled_frames=new_frames).to_dict() + new_labels = Labels.from_json(new_json, match_to=self) + new_frames = new_labels.labeled_frames # copy the labeled frames self.labeled_frames.extend(new_frames) @@ -644,24 +901,140 @@ def extend_from(self, new_frames): # update top level videos/nodes/skeletons/tracks self._update_from_labels(merge=True) - self._update_lookup_cache() + self._build_lookup_caches() return True - def merge_matching_frames(self, video=None): + @classmethod + def complex_merge_between( + cls, base_labels: "Labels", new_labels: "Labels", unify: bool = True + ) -> tuple: """ - Combine all instances from LabeledFrames that have same frame_idx. + Merge frames and other data from one dataset into another. + + Anything that can be merged cleanly is merged into base_labels. + + Frames conflict just in case each labels object has a matching + frame (same video and frame idx) with instances not in other. + + Frames can be merged cleanly if: + + * the frame is in only one of the labels, or + * the frame is in both labels, but all instances perfectly match + (which means they are redundant), or + * the frame is in both labels, maybe there are some redundant + instances, but only one version of the frame has additional + instances not in the other. Args: - video (optional): combine for this video; if None, do all videos + base_labels: the `Labels` that we're merging into + new_labels: the `Labels` that we're merging from + unify: whether to replace objects (e.g., `Video`) in + new_labels with *matching* objects from base + Returns: - none + tuple of three items: + + * Dictionary, keys are :class:`Video`, values are + dictionary in which keys are frame index (int) + and value is list of :class:`Instance` objects + * list of conflicting :class:`Instance` objects from base + * list of conflicting :class:`Instance` objects from new + + """ + # If unify, we want to replace objects in the frames with + # corresponding objects from the current labels. + # We do this by deserializing/serializing with match_to. + if unify: + new_json = new_labels.to_dict() + new_labels = cls.from_json(new_json, match_to=base_labels) + + # Merge anything that can be merged cleanly and get conflicts + merged, extra_base, extra_new = LabeledFrame.complex_merge_between( + base_labels=base_labels, new_frames=new_labels.labeled_frames + ) + + # For clean merge, finish merge now by cleaning up base object + if not extra_base and not extra_new: + # Add any new videos (etc) into top level lists in base + base_labels._update_from_labels(merge=True) + # Update caches + base_labels._build_lookup_caches() + + # Merge suggestions and negative anchors + cls.merge_container_dicts(base_labels.suggestions, new_labels.suggestions) + cls.merge_container_dicts( + base_labels.negative_anchors, new_labels.negative_anchors + ) + + return merged, extra_base, extra_new + + # @classmethod + # def merge_predictions_by_score(cls, extra_base: List[LabeledFrame], extra_new: List[LabeledFrame]): + # """ + # Remove all predictions from input lists, return list with only + # the merged predictions. + # + # Args: + # extra_base: list of `LabeledFrame` objects + # extra_new: list of `LabeledFrame` objects + # Conflicting frames should have same index in both lists. + # Returns: + # list of `LabeledFrame` objects with merged predictions + # """ + # pass + + @staticmethod + def finish_complex_merge( + base_labels: "Labels", resolved_frames: List[LabeledFrame] + ): + """ + Finish conflicted merge from complex_merge_between. + + Args: + base_labels: the `Labels` that we're merging into + resolved_frames: the list of frames to add into base_labels + Returns: + None. + """ + # Add all the resolved frames to base + base_labels.labeled_frames.extend(resolved_frames) + + # Combine instances when there are two LabeledFrames for same + # video and frame index + base_labels.merge_matching_frames() + + # Add any new videos (etc) into top level lists in base + base_labels._update_from_labels(merge=True) + # Update caches + base_labels._build_lookup_caches() + + @staticmethod + def merge_container_dicts(dict_a: Dict, dict_b: Dict) -> Dict: + """Merge data from dict_b into dict_a.""" + for key in dict_b.keys(): + if key in dict_a: + dict_a[key].extend(dict_b[key]) + uniquify(dict_a[key]) + else: + dict_a[key] = dict_b[key] + + def merge_matching_frames(self, video: Optional[Video] = None): + """ + Merge `LabeledFrame` objects that are for the same video frame. + + Args: + video: combine for this video; if None, do all videos + Returns: + None """ if video is None: for vid in {lf.video for lf in self.labeled_frames}: self.merge_matching_frames(video=vid) else: - self.labeled_frames = LabeledFrame.merge_frames(self.labeled_frames, video=video) + self.labeled_frames = LabeledFrame.merge_frames( + self.labeled_frames, video=video + ) def to_dict(self, skip_labels: bool = False): """ @@ -671,12 +1044,14 @@ def to_dict(self, skip_labels: bool = False): JSON and HDF5 serialized datasets. Args: - skip_labels: If True, skip labels serialization and just do the metadata. + skip_labels: If True, skip labels serialization and just do the + metadata. Returns: A dict containing the followings top level keys: * version - The version of the dict/json serialization format. - * skeletons - The skeletons associated with these underlying instances. + * skeletons - The skeletons associated with these underlying + instances. * nodes - The nodes that the skeletons represent. * videos - The videos that that the instances occur on. * labels - The labeled frames @@ -688,17 +1063,27 @@ def to_dict(self, skip_labels: bool = False): # FIXME: Update list of nodes # We shouldn't have to do this here, but for some reason we're missing nodes # which are in the skeleton but don't have points (in the first instance?). - self.nodes = list(set(self.nodes).union({node for skeleton in self.skeletons for node in skeleton.nodes})) + self.nodes = list( + set(self.nodes).union( + {node for skeleton in self.skeletons for node in skeleton.nodes} + ) + ) # Register some unstructure hooks since we don't want complete deserialization # of video and skeleton objects present in the labels. We will serialize these # as references to the above constructed lists to limit redundant data in the # json label_cattr = make_instance_cattr() - label_cattr.register_unstructure_hook(Skeleton, lambda x: str(self.skeletons.index(x))) - label_cattr.register_unstructure_hook(Video, lambda x: str(self.videos.index(x))) + label_cattr.register_unstructure_hook( + Skeleton, lambda x: str(self.skeletons.index(x)) + ) + label_cattr.register_unstructure_hook( + Video, lambda x: str(self.videos.index(x)) + ) label_cattr.register_unstructure_hook(Node, lambda x: str(self.nodes.index(x))) - label_cattr.register_unstructure_hook(Track, lambda x: str(self.tracks.index(x))) + label_cattr.register_unstructure_hook( + Track, lambda x: str(self.tracks.index(x)) + ) # Make a converter for the top level skeletons list. idx_to_node = {i: self.nodes[i] for i in range(len(self.nodes))} @@ -707,17 +1092,17 @@ def to_dict(self, skip_labels: bool = False): # Serialize the skeletons, videos, and labels dicts = { - 'version': LABELS_JSON_FILE_VERSION, - 'skeletons': skeleton_cattr.unstructure(self.skeletons), - 'nodes': cattr.unstructure(self.nodes), - 'videos': Video.cattr().unstructure(self.videos), - 'tracks': cattr.unstructure(self.tracks), - 'suggestions': label_cattr.unstructure(self.suggestions), - 'negative_anchors': label_cattr.unstructure(self.negative_anchors) - } + "version": LABELS_JSON_FILE_VERSION, + "skeletons": skeleton_cattr.unstructure(self.skeletons), + "nodes": cattr.unstructure(self.nodes), + "videos": Video.cattr().unstructure(self.videos), + "tracks": cattr.unstructure(self.tracks), + "suggestions": label_cattr.unstructure(self.suggestions), + "negative_anchors": label_cattr.unstructure(self.negative_anchors), + } if not skip_labels: - dicts['labels'] = label_cattr.unstructure(self.labeled_frames) + dicts["labels"] = label_cattr.unstructure(self.labeled_frames) return dicts @@ -727,47 +1112,54 @@ def to_json(self): JSON structured string. Returns: - The JSON representaiton of the string. + The JSON representation of the string. """ # Unstructure the data into dicts and dump to JSON. return json_dumps(self.to_dict()) @staticmethod - def save_json(labels: 'Labels', filename: str, - compress: bool = False, - save_frame_data: bool = False, - frame_data_format: str = 'png'): + def save_json( + labels: "Labels", + filename: str, + compress: bool = False, + save_frame_data: bool = False, + frame_data_format: str = "png", + ): """ Save a Labels instance to a JSON format. Args: labels: The labels dataset to save. filename: The filename to save the data to. - compress: Should the data be zip compressed or not? If True, the JSON will be - compressed using Python's shutil.make_archive command into a PKZIP zip file. If - compress is True then filename will have a .zip appended to it. - save_frame_data: Whether to save the image data for each frame as well. For each - video in the dataset, all frames that have labels will be stored as an imgstore - dataset. If save_frame_data is True then compress will be forced to True since - the archive must contain both the JSON data and image data stored in ImgStores. - frame_data_format: If save_frame_data is True, then this argument is used to set - the data format to use when writing frame data to ImgStore objects. Supported - formats should be: - - * 'pgm', - * 'bmp', - * 'ppm', - * 'tif', - * 'png', - * 'jpg', - * 'npy', - * 'mjpeg/avi', - * 'h264/mkv', - * 'avc1/mp4' - - Note: 'h264/mkv' and 'avc1/mp4' require separate installation of these codecs - on your system. They are excluded from sLEAP because of their GPL license. + compress: Whether the data be zip compressed or not? If True, + the JSON will be compressed using Python's shutil.make_archive + command into a PKZIP zip file. If compress is True then + filename will have a .zip appended to it. + save_frame_data: Whether to save the image data for each frame. + For each video in the dataset, all frames that have labels + will be stored as an imgstore dataset. + If save_frame_data is True then compress will be forced to True + since the archive must contain both the JSON data and image + data stored in ImgStores. + frame_data_format: If save_frame_data is True, then this argument + is used to set the data format to use when writing frame + data to ImgStore objects. Supported formats should be: + + * 'pgm', + * 'bmp', + * 'ppm', + * 'tif', + * 'png', + * 'jpg', + * 'npy', + * 'mjpeg/avi', + * 'h264/mkv', + * 'avc1/mp4' + + Note: 'h264/mkv' and 'avc1/mp4' require separate installation + of these codecs on your system. They are excluded from SLEAP + because of their GPL license. Returns: None @@ -784,20 +1176,26 @@ def save_json(labels: 'Labels', filename: str, # Create a set of new Video objects with imgstore backends. One for each # of the videos. We will only include the labeled frames though. We will # then replace each video with this new video - new_videos = labels.save_frame_data_imgstore(output_dir=tmp_dir, format=frame_data_format) + new_videos = labels.save_frame_data_imgstore( + output_dir=tmp_dir, format=frame_data_format + ) # Make video paths relative for vid in new_videos: tmp_path = vid.filename # Get the parent dir of the YAML file. # Use "/" since this works on Windows and posix - img_store_dir = os.path.basename(os.path.split(tmp_path)[0]) + "/" + os.path.basename(tmp_path) + img_store_dir = ( + os.path.basename(os.path.split(tmp_path)[0]) + + "/" + + os.path.basename(tmp_path) + ) # Change to relative path vid.backend.filename = img_store_dir # Convert to a dict, not JSON yet, because we need to patch up the videos d = labels.to_dict() - d['videos'] = Video.cattr().unstructure(new_videos) + d["videos"] = Video.cattr().unstructure(new_videos) else: d = labels.to_dict() @@ -813,14 +1211,31 @@ def save_json(labels: 'Labels', filename: str, json_dumps(d, full_out_filename) # Create the archive - shutil.make_archive(base_name=filename, root_dir=tmp_dir, format='zip') + shutil.make_archive(base_name=filename, root_dir=tmp_dir, format="zip") # If the user doesn't want to compress, then just write the json to the filename else: json_dumps(d, filename) @classmethod - def from_json(cls, data: Union[str, dict], match_to: Optional['Labels'] = None) -> 'Labels': + def from_json( + cls, data: Union[str, dict], match_to: Optional["Labels"] = None + ) -> "Labels": + """ + Create instance of class from data in dictionary. + + Method is used by other methods that load from JSON. + + Args: + data: Dictionary, deserialized from JSON. + match_to: If given, we'll replace particular objects in the + data dictionary with *matching* objects in the match_to + :class:`Labels` object. This ensures that the newly + instantiated :class:`Labels` can be merged without + duplicate matching objects (e.g., :class:`Video` objects ). + Returns: + A new :class:`Labels` object. + """ # Parse the json string if needed. if type(data) is str: @@ -828,16 +1243,20 @@ def from_json(cls, data: Union[str, dict], match_to: Optional['Labels'] = None) else: dicts = data - dicts['tracks'] = dicts.get('tracks', []) # don't break if json doesn't include tracks + dicts["tracks"] = dicts.get( + "tracks", [] + ) # don't break if json doesn't include tracks # First, deserialize the skeletons, videos, and nodes lists. # The labels reference these so we will need them while deserializing. - nodes = cattr.structure(dicts['nodes'], List[Node]) + nodes = cattr.structure(dicts["nodes"], List[Node]) - idx_to_node = {i:nodes[i] for i in range(len(nodes))} - skeletons = Skeleton.make_cattr(idx_to_node).structure(dicts['skeletons'], List[Skeleton]) - videos = Video.cattr().structure(dicts['videos'], List[Video]) - tracks = cattr.structure(dicts['tracks'], List[Track]) + idx_to_node = {i: nodes[i] for i in range(len(nodes))} + skeletons = Skeleton.make_cattr(idx_to_node).structure( + dicts["skeletons"], List[Skeleton] + ) + videos = Video.cattr().structure(dicts["videos"], List[Video]) + tracks = cattr.structure(dicts["tracks"], List[Track]) # if we're given a Labels object to match, use its objects when they match if match_to is not None: @@ -854,50 +1273,88 @@ def from_json(cls, data: Union[str, dict], match_to: Optional['Labels'] = None) for idx, vid in enumerate(videos): for old_vid in match_to.videos: # compare last three parts of path - weak_match = vid.filename.split("/")[-3:] == old_vid.filename.split("/")[-3:] - if vid.filename == old_vid.filename or weak_match: + if vid.filename == old_vid.filename or weak_filename_match( + vid.filename, old_vid.filename + ): # use video from match videos[idx] = old_vid break if "suggestions" in dicts: suggestions_cattr = cattr.Converter() - suggestions_cattr.register_structure_hook(Video, lambda x,type: videos[int(x)]) - suggestions = suggestions_cattr.structure(dicts['suggestions'], Dict[Video, List]) + suggestions_cattr.register_structure_hook( + Video, lambda x, type: videos[int(x)] + ) + suggestions = suggestions_cattr.structure( + dicts["suggestions"], Dict[Video, List] + ) else: suggestions = dict() if "negative_anchors" in dicts: negative_anchors_cattr = cattr.Converter() - negative_anchors_cattr.register_structure_hook(Video, lambda x,type: videos[int(x)]) - negative_anchors = negative_anchors_cattr.structure(dicts['negative_anchors'], Dict[Video, List]) + negative_anchors_cattr.register_structure_hook( + Video, lambda x, type: videos[int(x)] + ) + negative_anchors = negative_anchors_cattr.structure( + dicts["negative_anchors"], Dict[Video, List] + ) else: negative_anchors = dict() # If there is actual labels data, get it. - if 'labels' in dicts: + if "labels" in dicts: label_cattr = make_instance_cattr() - label_cattr.register_structure_hook(Skeleton, lambda x,type: skeletons[int(x)]) - label_cattr.register_structure_hook(Video, lambda x,type: videos[int(x)]) - label_cattr.register_structure_hook(Node, lambda x,type: x if isinstance(x,Node) else nodes[int(x)]) - label_cattr.register_structure_hook(Track, lambda x, type: None if x is None else tracks[int(x)]) - - labels = label_cattr.structure(dicts['labels'], List[LabeledFrame]) + label_cattr.register_structure_hook( + Skeleton, lambda x, type: skeletons[int(x)] + ) + label_cattr.register_structure_hook(Video, lambda x, type: videos[int(x)]) + label_cattr.register_structure_hook( + Node, lambda x, type: x if isinstance(x, Node) else nodes[int(x)] + ) + label_cattr.register_structure_hook( + Track, lambda x, type: None if x is None else tracks[int(x)] + ) + + labels = label_cattr.structure(dicts["labels"], List[LabeledFrame]) else: labels = [] - return cls(labeled_frames=labels, - videos=videos, - skeletons=skeletons, - nodes=nodes, - suggestions=suggestions, - negative_anchors=negative_anchors, - tracks=tracks) + return cls( + labeled_frames=labels, + videos=videos, + skeletons=skeletons, + nodes=nodes, + suggestions=suggestions, + negative_anchors=negative_anchors, + tracks=tracks, + ) @classmethod - def load_json(cls, filename: str, - video_callback=None, - match_to: Optional['Labels'] = None): + def load_json( + cls, + filename: str, + video_callback: Optional[Callable] = None, + match_to: Optional["Labels"] = None, + ) -> "Labels": + """ + Deserialize JSON file as new :class:`Labels` instance. + + Args: + filename: Path to JSON file. + video_callback: A callback function that which can modify + video paths before we try to create the corresponding + :class:`Video` objects. Usually you'll want to pass + a callback created by :meth:`make_video_callback` + or :meth:`make_gui_video_callback`. + match_to: If given, we'll replace particular objects in the + data dictionary with *matching* objects in the match_to + :class:`Labels` object. This ensures that the newly + instantiated :class:`Labels` can be merged without + duplicate matching objects (e.g., :class:`Video` objects ). + Returns: + A new :class:`Labels` object. + """ tmp_dir = None @@ -906,8 +1363,10 @@ def load_json(cls, filename: str, # Make a tmpdir, located in the directory that the file exists, to unzip # its contents. - tmp_dir = os.path.join(os.path.dirname(filename), - f"tmp_{os.path.basename(filename)}") + tmp_dir = os.path.join( + os.path.dirname(filename), + f"tmp_{os.getpid()}_{os.path.basename(filename)}", + ) if os.path.exists(tmp_dir): shutil.rmtree(tmp_dir, ignore_errors=True) try: @@ -915,7 +1374,7 @@ def load_json(cls, filename: str, except FileExistsError: pass - #tmp_dir = tempfile.mkdtemp(dir=os.path.dirname(filename)) + # tmp_dir = tempfile.mkdtemp(dir=os.path.dirname(filename)) try: @@ -928,10 +1387,16 @@ def load_json(cls, filename: str, # We can now open the JSON file, save the zip file and # replace file with the first JSON file we find in the archive. - json_files = [os.path.join(tmp_dir, file) for file in os.listdir(tmp_dir) if file.endswith(".json")] + json_files = [ + os.path.join(tmp_dir, file) + for file in os.listdir(tmp_dir) + if file.endswith(".json") + ] if len(json_files) == 0: - raise ValueError(f"No JSON file found inside {filename}. Are you sure this is a valid sLEAP dataset.") + raise ValueError( + f"No JSON file found inside {filename}. Are you sure this is a valid sLEAP dataset." + ) filename = json_files[0] @@ -941,7 +1406,7 @@ def load_json(cls, filename: str, raise # Open and parse the JSON in filename - with open(filename, 'r') as file: + with open(filename, "r") as file: # FIXME: Peek into the json to see if there is version string. # We do this to tell apart old JSON data from leap_dev vs the @@ -957,7 +1422,9 @@ def load_json(cls, filename: str, # Replace local video paths (for imagestore) if tmp_dir: for vid in dicts["videos"]: - vid["backend"]["filename"] = os.path.join(tmp_dir, vid["backend"]["filename"]) + vid["backend"]["filename"] = os.path.join( + tmp_dir, vid["backend"]["filename"] + ) # Use the callback if given to handle missing videos if callable(video_callback): @@ -983,40 +1450,50 @@ def load_json(cls, filename: str, except Exception as ex: # Ok, we give up, where the hell are these videos! - raise # Re-raise. + raise # Re-raise. finally: os.chdir(cwd) # Make sure to change back if we have problems. return labels else: - return load_labels_json_old(data_path=filename, parsed_json=dicts) + frames = load_labels_json_old(data_path=filename, parsed_json=dicts) + return Labels(frames) @staticmethod - def save_hdf5(labels: 'Labels', filename: str, - append: bool = False, - save_frame_data: bool = False): + def save_hdf5( + labels: "Labels", + filename: str, + append: bool = False, + save_frame_data: bool = False, + frame_data_format: str = "png", + ): """ Serialize the labels dataset to an HDF5 file. Args: - labels: The Labels dataset to save + labels: The :class:`Labels` dataset to save filename: The file to serialize the dataset to. - append: Whether to append these labeled frames to the file or - not. - save_frame_data: Whether to save the image frame data for any - labeled frame as well. This is useful for uploading the HDF5 for - model training when video files are to large to move. This will only - save video frames that have some labeled instances. + append: Whether to append these labeled frames to the file + or not. + save_frame_data: Whether to save the image frame data for + any labeled frame as well. This is useful for uploading + the HDF5 for model training when video files are to + large to move. This will only save video frames that + have some labeled instances. + frame_data_format: If save_frame_data is True, then this argument + is used to set the data format to use when encoding images + saved in HDF5. Supported formats include: + + * "" for no encoding (ndarray) + * "png" + * "jpg" + * anything else supported by `cv2.imencode` Returns: None """ - # FIXME: Need to implement this. - if save_frame_data: - raise NotImplementedError('Saving frame data is not implemented yet with HDF5 Labels datasets.') - # Delete the file if it exists, we want to start from scratch since # h5py truncates the file which seems to not actually delete data # from the file. Don't if we are appending of course. @@ -1026,16 +1503,30 @@ def save_hdf5(labels: 'Labels', filename: str, # Serialize all the meta-data to JSON. d = labels.to_dict(skip_labels=True) - with h5.File(filename, 'a') as f: + if save_frame_data: + new_videos = labels.save_frame_data_hdf5(filename, frame_data_format) + + # Replace path to video file with "." (which indicates that the + # video is in the same file as the HDF5 labels dataset). + # Otherwise, the video paths will break if the HDF5 labels + # dataset file is moved. + for vid in new_videos: + vid.backend.filename = "." + + d["videos"] = Video.cattr().unstructure(new_videos) + + with h5.File(filename, "a") as f: # Add all the JSON metadata - meta_group = f.require_group('metadata') + meta_group = f.require_group("metadata") # If we are appending and there already exists JSON metadata - if append and 'json' in meta_group.attrs: + if append and "json" in meta_group.attrs: # Otherwise, we need to read the JSON and append to the lists - old_labels = Labels.from_json(meta_group.attrs['json'].tostring().decode()) + old_labels = Labels.from_json( + meta_group.attrs["json"].tostring().decode() + ) # A function to join to list but only include new non-dupe entries # from the right hand list. @@ -1072,50 +1563,73 @@ def append_unique(old, new): d = labels.to_dict(skip_labels=True) # Output the dict to JSON - meta_group.attrs['json'] = np.string_(json_dumps(d)) + meta_group.attrs["json"] = np.string_(json_dumps(d)) # FIXME: We can probably construct these from attrs fields # We will store Instances and PredcitedInstances in the same # table. instance_type=0 or Instance and instance_type=1 for # PredictedInstance, score will be ignored for Instances. - instance_dtype = np.dtype([('instance_id', 'i8'), - ('instance_type', 'u1'), - ('frame_id', 'u8'), - ('skeleton', 'u4'), - ('track', 'i4'), - ('from_predicted', 'i8'), - ('score', 'f4'), - ('point_id_start', 'u8'), - ('point_id_end', 'u8')]) - frame_dtype = np.dtype([('frame_id', 'u8'), - ('video', 'u4'), - ('frame_idx', 'u8'), - ('instance_id_start', 'u8'), - ('instance_id_end', 'u8')]) + instance_dtype = np.dtype( + [ + ("instance_id", "i8"), + ("instance_type", "u1"), + ("frame_id", "u8"), + ("skeleton", "u4"), + ("track", "i4"), + ("from_predicted", "i8"), + ("score", "f4"), + ("point_id_start", "u8"), + ("point_id_end", "u8"), + ] + ) + frame_dtype = np.dtype( + [ + ("frame_id", "u8"), + ("video", "u4"), + ("frame_idx", "u8"), + ("instance_id_start", "u8"), + ("instance_id_end", "u8"), + ] + ) num_instances = len(labels.all_instances) max_skeleton_size = max([len(s.nodes) for s in labels.skeletons]) # Initialize data arrays for serialization points = np.zeros(num_instances * max_skeleton_size, dtype=Point.dtype) - pred_points = np.zeros(num_instances * max_skeleton_size, dtype=PredictedPoint.dtype) + pred_points = np.zeros( + num_instances * max_skeleton_size, dtype=PredictedPoint.dtype + ) instances = np.zeros(num_instances, dtype=instance_dtype) frames = np.zeros(len(labels), dtype=frame_dtype) # Pre compute some structures to make serialization faster - skeleton_to_idx = {skeleton: labels.skeletons.index(skeleton) for skeleton in labels.skeletons} - track_to_idx = {track: labels.tracks.index(track) for track in labels.tracks} + skeleton_to_idx = { + skeleton: labels.skeletons.index(skeleton) + for skeleton in labels.skeletons + } + track_to_idx = { + track: labels.tracks.index(track) for track in labels.tracks + } track_to_idx[None] = -1 - video_to_idx = {video: labels.videos.index(video) for video in labels.videos} + video_to_idx = { + video: labels.videos.index(video) for video in labels.videos + } instance_type_to_idx = {Instance: 0, PredictedInstance: 1} + # Each instance we create will have and index in the dataset, keep track of + # these so we can quickly add from_predicted links on a second pass. + instance_to_idx = {} + instances_with_from_predicted = [] + instances_from_predicted = [] + # If we are appending, we need look inside to see what frame, instance, and point # ids we need to start from. This gives us offsets to use. - if append and 'points' in f: - point_id_offset = f['points'].shape[0] - pred_point_id_offset = f['pred_points'].shape[0] - instance_id_offset = f['instances'][-1]['instance_id'] + 1 - frame_id_offset = int(f['frames'][-1]['frame_id']) + 1 + if append and "points" in f: + point_id_offset = f["points"].shape[0] + pred_point_id_offset = f["pred_points"].shape[0] + instance_id_offset = f["instances"][-1]["instance_id"] + 1 + frame_id_offset = int(f["frames"][-1]["frame_id"]) + 1 else: point_id_offset = 0 pred_point_id_offset = 0 @@ -1125,14 +1639,22 @@ def append_unique(old, new): point_id = 0 pred_point_id = 0 instance_id = 0 - frame_id = 0 - all_from_predicted = [] - from_predicted_id = 0 + for frame_id, label in enumerate(labels): - frames[frame_id] = (frame_id+frame_id_offset, video_to_idx[label.video], label.frame_idx, - instance_id+instance_id_offset, instance_id+instance_id_offset+len(label.instances)) + frames[frame_id] = ( + frame_id + frame_id_offset, + video_to_idx[label.video], + label.frame_idx, + instance_id + instance_id_offset, + instance_id + instance_id_offset + len(label.instances), + ) for instance in label.instances: - parray = instance.points_array(copy=False, full=True) + + # Add this instance to our lookup structure we will need for from_predicted + # links + instance_to_idx[instance] = instance_id + + parray = instance.get_points_array(copy=False, full=True) instance_type = type(instance) # Check whether we are working with a PredictedInstance or an Instance. @@ -1146,61 +1668,122 @@ def append_unique(old, new): # Keep track of any from_predicted instance links, we will insert the # correct instance_id in the dataset after we are done. if instance.from_predicted: - all_from_predicted.append(instance.from_predicted) - from_predicted_id = from_predicted_id + 1 + instances_with_from_predicted.append(instance_id) + instances_from_predicted.append(instance.from_predicted) # Copy all the data - instances[instance_id] = (instance_id+instance_id_offset, - instance_type_to_idx[instance_type], - frame_id, - skeleton_to_idx[instance.skeleton], - track_to_idx[instance.track], - -1, - score, - pid, pid + len(parray)) + instances[instance_id] = ( + instance_id + instance_id_offset, + instance_type_to_idx[instance_type], + frame_id, + skeleton_to_idx[instance.skeleton], + track_to_idx[instance.track], + -1, + score, + pid, + pid + len(parray), + ) # If these are predicted points, copy them to the predicted point array # otherwise, use the normal point array if type(parray) is PredictedPointArray: - pred_points[pred_point_id:pred_point_id + len(parray)] = parray + pred_points[ + pred_point_id : pred_point_id + len(parray) + ] = parray pred_point_id = pred_point_id + len(parray) else: - points[point_id:point_id + len(parray)] = parray + points[point_id : point_id + len(parray)] = parray point_id = point_id + len(parray) instance_id = instance_id + 1 + # Add from_predicted links + for instance_id, from_predicted in zip( + instances_with_from_predicted, instances_from_predicted + ): + try: + instances[instance_id]["from_predicted"] = instance_to_idx[ + from_predicted + ] + except KeyError: + # If we haven't encountered the from_predicted instance yet then don't save the link. + # It’s possible for a user to create a regular instance from a predicted instance and then + # delete all predicted instances from the file, but in this case I don’t think there’s any reason + # to remember which predicted instance the regular instance came from. + pass + # We pre-allocated our points array with max possible size considering the max # skeleton size, drop any unused points. points = points[0:point_id] pred_points = pred_points[0:pred_point_id] # Create datasets if we need to - if append and 'points' in f: - f['points'].resize((f["points"].shape[0] + points.shape[0]), axis = 0) - f['points'][-points.shape[0]:] = points - f['pred_points'].resize((f["pred_points"].shape[0] + pred_points.shape[0]), axis=0) - f['pred_points'][-pred_points.shape[0]:] = pred_points - f['instances'].resize((f["instances"].shape[0] + instances.shape[0]), axis=0) - f['instances'][-instances.shape[0]:] = instances - f['frames'].resize((f["frames"].shape[0] + frames.shape[0]), axis=0) - f['frames'][-frames.shape[0]:] = frames + if append and "points" in f: + f["points"].resize((f["points"].shape[0] + points.shape[0]), axis=0) + f["points"][-points.shape[0] :] = points + f["pred_points"].resize( + (f["pred_points"].shape[0] + pred_points.shape[0]), axis=0 + ) + f["pred_points"][-pred_points.shape[0] :] = pred_points + f["instances"].resize( + (f["instances"].shape[0] + instances.shape[0]), axis=0 + ) + f["instances"][-instances.shape[0] :] = instances + f["frames"].resize((f["frames"].shape[0] + frames.shape[0]), axis=0) + f["frames"][-frames.shape[0] :] = frames else: - f.create_dataset("points", data=points, maxshape=(None,), dtype=Point.dtype) - f.create_dataset("pred_points", data=pred_points, maxshape=(None,), dtype=PredictedPoint.dtype) - f.create_dataset("instances", data=instances, maxshape=(None,), dtype=instance_dtype) - f.create_dataset("frames", data=frames, maxshape=(None,), dtype=frame_dtype) + f.create_dataset( + "points", data=points, maxshape=(None,), dtype=Point.dtype + ) + f.create_dataset( + "pred_points", + data=pred_points, + maxshape=(None,), + dtype=PredictedPoint.dtype, + ) + f.create_dataset( + "instances", data=instances, maxshape=(None,), dtype=instance_dtype + ) + f.create_dataset( + "frames", data=frames, maxshape=(None,), dtype=frame_dtype + ) @classmethod - def load_hdf5(cls, filename: str, - video_callback=None, - match_to: Optional['Labels'] = None): + def load_hdf5( + cls, filename: str, video_callback=None, match_to: Optional["Labels"] = None + ): + """ + Deserialize HDF5 file as new :class:`Labels` instance. + + Args: + filename: Path to HDF5 file. + video_callback: A callback function that which can modify + video paths before we try to create the corresponding + :class:`Video` objects. Usually you'll want to pass + a callback created by :meth:`make_video_callback` + or :meth:`make_gui_video_callback`. + match_to: If given, we'll replace particular objects in the + data dictionary with *matching* objects in the match_to + :class:`Labels` object. This ensures that the newly + instantiated :class:`Labels` can be merged without + duplicate matching objects (e.g., :class:`Video` objects ). - with h5.File(filename, 'r') as f: + Returns: + A new :class:`Labels` object. + """ + with h5.File(filename, "r") as f: # Extract the Labels JSON metadata and create Labels object with just # this metadata. - dicts = json_loads(f.require_group('metadata').attrs['json'].tostring().decode()) + dicts = json_loads( + f.require_group("metadata").attrs["json"].tostring().decode() + ) + + # Video path "." means the video is saved in same file as labels, + # so replace these paths. + for video_item in dicts["videos"]: + if video_item["backend"]["filename"] == ".": + video_item["backend"]["filename"] = filename # Use the callback if given to handle missing videos if callable(video_callback): @@ -1208,52 +1791,78 @@ def load_hdf5(cls, filename: str, labels = cls.from_json(dicts, match_to=match_to) - frames_dset = f['frames'][:] - instances_dset = f['instances'][:] - points_dset = f['points'][:] - pred_points_dset = f['pred_points'][:] + frames_dset = f["frames"][:] + instances_dset = f["instances"][:] + points_dset = f["points"][:] + pred_points_dset = f["pred_points"][:] # Rather than instantiate a bunch of Point\PredictedPoint objects, we will # use inplace numpy recarrays. This will save a lot of time and memory # when reading things in. points = PointArray(buf=points_dset, shape=len(points_dset)) - pred_points = PredictedPointArray(buf=pred_points_dset, shape=len(pred_points_dset)) + pred_points = PredictedPointArray( + buf=pred_points_dset, shape=len(pred_points_dset) + ) # Extend the tracks list with a None track. We will signify this with a -1 in the # data which will map to last element of tracks tracks = labels.tracks.copy() tracks.extend([None]) + # A dict to keep track of instances that have a from_predicted link. The key is the + # instance and the value is the index of the instance. + from_predicted_lookup = {} + # Create the instances instances = [] for i in instances_dset: - track = tracks[i['track']] - skeleton = labels.skeletons[i['skeleton']] - - if i['instance_type'] == 0: # Instance - instance = Instance(skeleton=skeleton, track=track, - points=points[i['point_id_start']:i['point_id_end']]) - else: # PredictedInstance - instance = PredictedInstance(skeleton=skeleton, track=track, - points=pred_points[i['point_id_start']:i['point_id_end']], - score=i['score']) + track = tracks[i["track"]] + skeleton = labels.skeletons[i["skeleton"]] + + if i["instance_type"] == 0: # Instance + instance = Instance( + skeleton=skeleton, + track=track, + points=points[i["point_id_start"] : i["point_id_end"]], + ) + else: # PredictedInstance + instance = PredictedInstance( + skeleton=skeleton, + track=track, + points=pred_points[i["point_id_start"] : i["point_id_end"]], + score=i["score"], + ) instances.append(instance) + if i["from_predicted"] != -1: + from_predicted_lookup[instance] = i["from_predicted"] + + # Make a second pass to add any from_predicted links + for instance, from_predicted_idx in from_predicted_lookup.items(): + instance.from_predicted = instances[from_predicted_idx] + # Create the labeled frames - frames = [LabeledFrame(video=labels.videos[frame['video']], - frame_idx=frame['frame_idx'], - instances=instances[frame['instance_id_start']:frame['instance_id_end']]) - for i, frame in enumerate(frames_dset)] + frames = [ + LabeledFrame( + video=labels.videos[frame["video"]], + frame_idx=frame["frame_idx"], + instances=instances[ + frame["instance_id_start"] : frame["instance_id_end"] + ], + ) + for i, frame in enumerate(frames_dset) + ] labels.labeled_frames = frames # Do the stuff that should happen after we have labeled frames - labels._update_lookup_cache() + labels._build_lookup_caches() return labels @classmethod def load_file(cls, filename: str, *args, **kwargs): + """Load file, detecting format from filename.""" if filename.endswith((".h5", ".hdf5")): return cls.load_hdf5(filename, *args, **kwargs) elif filename.endswith((".json", ".json.zip")): @@ -1266,31 +1875,73 @@ def load_file(cls, filename: str, *args, **kwargs): else: raise ValueError(f"Cannot detect filetype for {filename}") - def save_frame_data_imgstore(self, output_dir: str = './', format: str = 'png', all_labels: bool = False): + @classmethod + def save_file( + cls, labels: "Labels", filename: str, default_suffix: str = "", *args, **kwargs + ): + """Save file, detecting format from filename. + + Args: + labels: The dataset to save. + filename: Path where we'll save it. We attempt to detect format + from the suffix (e.g., ".json"). + default_suffix: If we can't detect valid suffix on filename, + we can add default suffix to filename (and use corresponding + format). Doesn't need to have "." before file extension. + + Raises: + ValueError: If cannot detect valid filetype. + + Returns: + None. + """ + if not filename.endswith((".json", ".zip", ".h5")) and default_suffix: + filename += f".{default_suffix}" + if filename.endswith((".json", ".zip")): + compress = filename.endswith(".zip") + cls.save_json(labels=labels, filename=filename, compress=compress, **kwargs) + elif filename.endswith(".h5"): + cls.save_hdf5(labels=labels, filename=filename, **kwargs) + else: + raise ValueError(f"Cannot detect filetype for {filename}") + + def save_frame_data_imgstore( + self, output_dir: str = "./", format: str = "png", all_labels: bool = False + ): """ - Write all labeled frames from all videos to a collection of imgstore datasets. - This only writes frames that have been labeled. Videos without any labeled frames - will be included as empty imgstores. + Write all labeled frames from all videos to imgstore datasets. + + This only writes frames that have been labeled. Videos without + any labeled frames will be included as empty imgstores. Args: - output_dir: - format: The image format to use for the data. png for lossless, jpg for lossy. - Other imgstore formats will probably work as well but have not been tested. + output_dir: Path to directory which will contain imgstores. + format: The image format to use for the data. + Use "png" for lossless, "jpg" for lossy. + Other imgstore formats will probably work as well but + have not been tested. all_labels: Include any labeled frames, not just the frames - we'll use for training (i.e., those with Instances). + we'll use for training (i.e., those with `Instance` objects ). Returns: - A list of ImgStoreVideo objects that represent the stored frames. + A list of :class:`ImgStoreVideo` objects with the stored + frames. """ # For each label imgstore_vids = [] for v_idx, v in enumerate(self.videos): - frame_nums = [lf.frame_idx for lf in self.labeled_frames - if v == lf.video - and (all_labels or lf.has_user_instances)] + frame_nums = [ + lf.frame_idx + for lf in self.labeled_frames + if v == lf.video and (all_labels or lf.has_user_instances) + ] - frames_filename = os.path.join(output_dir, f'frame_data_vid{v_idx}') - vid = v.to_imgstore(path=frames_filename, frame_numbers=frame_nums, format=format) + # Join with "/" instead of os.path.join() since we want + # path to work on Windows and Posix systems + frames_filename = output_dir + f"/frame_data_vid{v_idx}" + vid = v.to_imgstore( + path=frames_filename, frame_numbers=frame_nums, format=format + ) # Close the video for now vid.close() @@ -1299,9 +1950,43 @@ def save_frame_data_imgstore(self, output_dir: str = './', format: str = 'png', return imgstore_vids + def save_frame_data_hdf5( + self, output_path: str, format: str = "png", all_labels: bool = False + ): + """ + Write labeled frames from all videos to hdf5 file. + + Args: + output_path: Path to HDF5 file. + format: The image format to use for the data. Defaults to png. + all_labels: Include any labeled frames, not just the frames + we'll use for training (i.e., those with Instances). + + Returns: + A list of :class:`HDF5Video` objects with the stored frames. + """ + new_vids = [] + for v_idx, v in enumerate(self.videos): + frame_nums = [ + lf.frame_idx + for lf in self.labeled_frames + if v == lf.video and (all_labels or lf.has_user_instances) + ] + + vid = v.to_hdf5( + path=output_path, + dataset=f"video{v_idx}", + format=format, + frame_numbers=frame_nums, + ) + vid.close() + new_vids.append(vid) + + return new_vids @staticmethod def _unwrap_mat_scalar(a): + """Extract single value from nested MATLAB file data.""" if a.shape == (1,): return Labels._unwrap_mat_scalar(a[0]) else: @@ -1309,12 +1994,20 @@ def _unwrap_mat_scalar(a): @staticmethod def _unwrap_mat_array(a): + """Extract list of values from nested MATLAB file data.""" b = a[0][0] c = [Labels._unwrap_mat_scalar(x) for x in b] return c @classmethod - def load_mat(cls, filename): + def load_mat(cls, filename: str) -> "Labels": + """Load LEAP MATLAB file as dataset. + + Args: + filename: Path to csv file. + Returns: + The :class:`Labels` dataset. + """ mat_contents = sio.loadmat(filename) box_path = Labels._unwrap_mat_scalar(mat_contents["boxPath"]) @@ -1322,11 +2015,13 @@ def load_mat(cls, filename): # If the video file isn't found, try in the same dir as the mat file if not os.path.exists(box_path): file_dir = os.path.dirname(filename) - box_path_name = box_path.split("\\")[-1] # assume windows path + box_path_name = box_path.split("\\")[-1] # assume windows path box_path = os.path.join(file_dir, box_path_name) if os.path.exists(box_path): - vid = Video.from_hdf5(dataset="box", filename=box_path, input_format="channels_first") + vid = Video.from_hdf5( + dataset="box", filename=box_path, input_format="channels_first" + ) else: vid = None @@ -1336,12 +2031,12 @@ def load_mat(cls, filename): edges_ = mat_contents["skeleton"]["edges"] points_ = mat_contents["positions"] - edges_ = edges_ - 1 # convert matlab 1-indexing to python 0-indexing + edges_ = edges_ - 1 # convert matlab 1-indexing to python 0-indexing nodes = Labels._unwrap_mat_array(nodes_) edges = Labels._unwrap_mat_array(edges_) - nodes = list(map(str, nodes)) # convert np._str to str + nodes = list(map(str, nodes)) # convert np._str to str sk = Skeleton(name=filename) sk.add_nodes(nodes) @@ -1352,14 +2047,14 @@ def load_mat(cls, filename): node_count, _, frame_count = points_.shape for i in range(frame_count): - new_inst = Instance(skeleton = sk) + new_inst = Instance(skeleton=sk) for node_idx, node in enumerate(nodes): x = points_[node_idx][0][i] y = points_[node_idx][1][i] new_inst[node] = Point(x, y) - if len(new_inst.points()): + if len(new_inst.points): new_frame = LabeledFrame(video=vid, frame_idx=i) - new_frame.instances = new_inst, + new_frame.instances = (new_inst,) labeled_frames.append(new_frame) labels = cls(labeled_frames=labeled_frames, videos=[vid], skeletons=[sk]) @@ -1367,7 +2062,14 @@ def load_mat(cls, filename): return labels @classmethod - def load_deeplabcut_csv(cls, filename): + def load_deeplabcut_csv(cls, filename: str) -> "Labels": + """Load DeepLabCut csv file as dataset. + + Args: + filename: Path to csv file. + Returns: + The :class:`Labels` dataset. + """ # At the moment we don't need anything from the config file, # but the code to read it is here in case we do in the future. @@ -1395,7 +2097,7 @@ def load_deeplabcut_csv(cls, filename): # x2 = config['x2'] # y2 = config['y2'] - data = pd.read_csv(filename, header=[1,2]) + data = pd.read_csv(filename, header=[1, 2]) # Create the skeleton from the list of nodes in the csv file # Note that DeepLabCut doesn't have edges, so these will have to be added by user later @@ -1408,7 +2110,7 @@ def load_deeplabcut_csv(cls, filename): # This may not be ideal for large projects, since we're reading in # each image and then writing it out in a new directory. - img_files = data.ix[:,0] # get list of all images + img_files = data.ix[:, 0] # get list of all images # the image filenames in the csv may not match where the user has them # so we'll change the directory to match where the user has the csv @@ -1435,7 +2137,7 @@ def fix_img_path(img_dir, img_filename): # get points for each node instance_points = dict() for node in node_names: - x, y = data[(node, 'x')][i], data[(node, 'y')][i] + x, y = data[(node, "x")][i], data[(node, "y")][i] instance_points[node] = Point(x, y) # create instance with points (we can assume there's only one instance per frame) instance = Instance(skeleton=skeleton, points=instance_points) @@ -1446,8 +2148,23 @@ def fix_img_path(img_dir, img_filename): return cls(labels) @classmethod - def make_video_callback(cls, search_paths=None): + def make_video_callback(cls, search_paths: Optional[List] = None) -> Callable: + """ + Create a non-GUI callback for finding missing videos. + + The callback can be used while loading a saved project and + allows the user to find videos which have been moved (or have + paths from a different system). + + Args: + search_paths: If specified, this is a list of paths where + we'll automatically try to find the missing videos. + + Returns: + The callback function. + """ search_paths = search_paths or [] + def video_callback(video_list, new_paths=search_paths): # Check each video for video_item in video_list: @@ -1455,7 +2172,6 @@ def video_callback(video_list, new_paths=search_paths): current_filename = video_item["backend"]["filename"] # check if we can find video if not os.path.exists(current_filename): - is_found = False current_basename = os.path.basename(current_filename) # handle unix, windows, or mixed paths @@ -1474,19 +2190,32 @@ def video_callback(video_list, new_paths=search_paths): if os.path.exists(check_path): # we found the file in a different directory video_item["backend"]["filename"] = check_path - is_found = True break + return video_callback @classmethod - def make_gui_video_callback(cls, search_paths): + def make_gui_video_callback(cls, search_paths: Optional[List] = None) -> Callable: + """ + Create a callback with GUI for finding missing videos. + + The callback can be used while loading a saved project and + allows the user to find videos which have been moved (or have + paths from a different system). + + Args: + search_paths: If specified, this is a list of paths where + we'll automatically try to find the missing videos. + + Returns: + The callback function. + """ search_paths = search_paths or [] + def gui_video_callback(video_list, new_paths=search_paths): import os from PySide2.QtWidgets import QFileDialog, QMessageBox - has_shown_prompt = False # have we already alerted user about missing files? - basename_list = [] # Check each video @@ -1519,120 +2248,30 @@ def gui_video_callback(video_list, new_paths=search_paths): break # if we found this file, then move on to the next file - if is_found: continue + if is_found: + continue # Since we couldn't find the file on our own, prompt the user. print(f"Unable to find: {current_filename}") - QMessageBox(text=f"We're unable to locate one or more video files for this project. Please locate {current_filename}.").exec_() - has_shown_prompt = True + QMessageBox( + text=f"We're unable to locate one or more video files for this project. Please locate {current_filename}." + ).exec_() current_root, current_ext = os.path.splitext(current_basename) caption = f"Please locate {current_basename}..." - filters = [f"{current_root} file (*{current_ext})", "Any File (*.*)"] + filters = [ + f"{current_root} file (*{current_ext})", + "Any File (*.*)", + ] dir = None if len(new_paths) == 0 else new_paths[-1] - new_filename, _ = QFileDialog.getOpenFileName(None, dir=dir, caption=caption, filter=";;".join(filters)) + new_filename, _ = QFileDialog.getOpenFileName( + None, dir=dir, caption=caption, filter=";;".join(filters) + ) # if we got an answer, then update filename for video if len(new_filename): video_item["backend"]["filename"] = new_filename # keep track of the directory chosen by user new_paths.append(os.path.dirname(new_filename)) basename_list.append(current_basename) - return gui_video_callback - - -def load_labels_json_old(data_path: str, parsed_json: dict = None, - adjust_matlab_indexing: bool = True, - fix_rel_paths: bool = True) -> Labels: - """ - Simple utitlity code to load data from Talmo's old JSON format into newer - Labels object. - - Args: - data_path: The path to the JSON file. - parsed_json: The parsed json if already loaded. Save some time if already parsed. - adjust_matlab_indexing: Do we need to adjust indexing from MATLAB. - fix_rel_paths: Fix paths to videos to absolute paths. - Returns: - A newly constructed Labels object. - """ - if parsed_json is None: - data = json_loads(open(data_path).read()) - else: - data = parsed_json - - videos = pd.DataFrame(data["videos"]) - instances = pd.DataFrame(data["instances"]) - points = pd.DataFrame(data["points"]) - predicted_instances = pd.DataFrame(data["predicted_instances"]) - predicted_points = pd.DataFrame(data["predicted_points"]) - - if adjust_matlab_indexing: - instances.frameIdx -= 1 - points.frameIdx -= 1 - predicted_instances.frameIdx -= 1 - predicted_points.frameIdx -= 1 - - points.node -= 1 - predicted_points.node -= 1 - - points.x -= 1 - predicted_points.x -= 1 - - points.y -= 1 - predicted_points.y -= 1 - - skeleton = Skeleton() - skeleton.add_nodes(data["skeleton"]["nodeNames"]) - edges = data["skeleton"]["edges"] - if adjust_matlab_indexing: - edges = np.array(edges) - 1 - for (src_idx, dst_idx) in edges: - skeleton.add_edge(data["skeleton"]["nodeNames"][src_idx], data["skeleton"]["nodeNames"][dst_idx]) - - if fix_rel_paths: - for i, row in videos.iterrows(): - p = row.filepath - if not os.path.exists(p): - p = os.path.join(os.path.dirname(data_path), p) - if os.path.exists(p): - videos.at[i, "filepath"] = p - - # Make the video objects - video_objects = {} - for i, row in videos.iterrows(): - if videos.at[i, "format"] == "media": - vid = Video.from_media(videos.at[i, "filepath"]) - else: - vid = Video.from_hdf5(filename=videos.at[i, "filepath"], dataset=videos.at[i, "dataset"]) - - video_objects[videos.at[i, "id"]] = vid - - # A function to get all the instances for a particular video frame - def get_frame_instances(video_id, frame_idx): - is_in_frame = (points["videoId"] == video_id) & (points["frameIdx"] == frame_idx) - if not is_in_frame.any(): - return [] - - instances = [] - frame_instance_ids = np.unique(points["instanceId"][is_in_frame]) - for i, instance_id in enumerate(frame_instance_ids): - is_instance = is_in_frame & (points["instanceId"] == instance_id) - instance_points = {data["skeleton"]["nodeNames"][n]: Point(x, y, visible=v) for x, y, n, v in - zip(*[points[k][is_instance] for k in ["x", "y", "node", "visible"]])} - - instance = Instance(skeleton=skeleton, points=instance_points) - instances.append(instance) - - return instances - - # Get the unique labeled frames and construct a list of LabeledFrame objects for them. - frame_keys = list({(videoId, frameIdx) for videoId, frameIdx in zip(points["videoId"], points["frameIdx"])}) - frame_keys.sort() - labels = [] - for videoId, frameIdx in frame_keys: - label = LabeledFrame(video=video_objects[videoId], frame_idx=frameIdx, - instances = get_frame_instances(videoId, frameIdx)) - labels.append(label) - - return Labels(labels) + return gui_video_callback diff --git a/sleap/io/legacy.py b/sleap/io/legacy.py index 4f6c36293..ebcb38256 100644 --- a/sleap/io/legacy.py +++ b/sleap/io/legacy.py @@ -1,32 +1,45 @@ +""" +Module for legacy LEAP dataset. +""" import json import os import numpy as np import pandas as pd -from .dataset import Labels -from .video import Video +from typing import List -from ..instance import LabeledFrame, PredictedPoint, PredictedInstance -from ..skeleton import Skeleton +from sleap.util import json_loads +from sleap.io.video import Video + +from sleap.instance import ( + LabeledFrame, + PredictedPoint, + PredictedInstance, + Track, + Point, + Instance, +) +from sleap.skeleton import Skeleton -from ..nn.tracking import Track def load_predicted_labels_json_old( - data_path: str, parsed_json: dict = None, - adjust_matlab_indexing: bool = True, - fix_rel_paths: bool = True) -> Labels: + data_path: str, + parsed_json: dict = None, + adjust_matlab_indexing: bool = True, + fix_rel_paths: bool = True, +) -> List[LabeledFrame]: """ - Simple utitlity code to load data from Talmo's old JSON format into newer - Labels object. This loads the prediced instances + Load predicted instances from Talmo's old JSON format. Args: data_path: The path to the JSON file. - parsed_json: The parsed json if already loaded. Save some time if already parsed. - adjust_matlab_indexing: Do we need to adjust indexing from MATLAB. - fix_rel_paths: Fix paths to videos to absolute paths. + parsed_json: The parsed json if already loaded, so we can save + some time if already parsed. + adjust_matlab_indexing: Whether to adjust indexing from MATLAB. + fix_rel_paths: Whether to fix paths to videos to absolute paths. Returns: - A newly constructed Labels object. + List of :class:`LabeledFrame` objects. """ if parsed_json is None: data = json.loads(open(data_path).read()) @@ -53,7 +66,10 @@ def load_predicted_labels_json_old( if adjust_matlab_indexing: edges = np.array(edges) - 1 for (src_idx, dst_idx) in edges: - skeleton.add_edge(data["skeleton"]["nodeNames"][src_idx], data["skeleton"]["nodeNames"][dst_idx]) + skeleton.add_edge( + data["skeleton"]["nodeNames"][src_idx], + data["skeleton"]["nodeNames"][dst_idx], + ) if fix_rel_paths: for i, row in videos.iterrows(): @@ -69,22 +85,174 @@ def load_predicted_labels_json_old( if videos.at[i, "format"] == "media": vid = Video.from_media(videos.at[i, "filepath"]) else: - vid = Video.from_hdf5(filename=videos.at[i, "filepath"], dataset=videos.at[i, "dataset"]) + vid = Video.from_hdf5( + filename=videos.at[i, "filepath"], dataset=videos.at[i, "dataset"] + ) video_objects[videos.at[i, "id"]] = vid - track_ids = predicted_instances['trackId'].values + track_ids = predicted_instances["trackId"].values unique_track_ids = np.unique(track_ids) - spawned_on = {track_id: predicted_instances.loc[predicted_instances['trackId'] == track_id]['frameIdx'].values[0] - for track_id in unique_track_ids} - tracks = {i: Track(name=str(i), spawned_on=spawned_on[i]) - for i in np.unique(predicted_instances['trackId'].values).tolist()} + spawned_on = { + track_id: predicted_instances.loc[predicted_instances["trackId"] == track_id][ + "frameIdx" + ].values[0] + for track_id in unique_track_ids + } + tracks = { + i: Track(name=str(i), spawned_on=spawned_on[i]) + for i in np.unique(predicted_instances["trackId"].values).tolist() + } # A function to get all the instances for a particular video frame def get_frame_predicted_instances(video_id, frame_idx): points = predicted_points - is_in_frame = (points["videoId"] == video_id) & (points["frameIdx"] == frame_idx) + is_in_frame = (points["videoId"] == video_id) & ( + points["frameIdx"] == frame_idx + ) + if not is_in_frame.any(): + return [] + + instances = [] + frame_instance_ids = np.unique(points["instanceId"][is_in_frame]) + for i, instance_id in enumerate(frame_instance_ids): + is_instance = is_in_frame & (points["instanceId"] == instance_id) + track_id = predicted_instances.loc[ + predicted_instances["id"] == instance_id + ]["trackId"].values[0] + match_score = predicted_instances.loc[ + predicted_instances["id"] == instance_id + ]["matching_score"].values[0] + track_score = predicted_instances.loc[ + predicted_instances["id"] == instance_id + ]["tracking_score"].values[0] + instance_points = { + data["skeleton"]["nodeNames"][n]: PredictedPoint( + x, y, visible=v, score=confidence + ) + for x, y, n, v, confidence in zip( + *[ + points[k][is_instance] + for k in ["x", "y", "node", "visible", "confidence"] + ] + ) + } + + instance = PredictedInstance( + skeleton=skeleton, + points=instance_points, + track=tracks[track_id], + score=match_score, + ) + instances.append(instance) + + return instances + + # Get the unique labeled frames and construct a list of LabeledFrame objects for them. + frame_keys = list( + { + (videoId, frameIdx) + for videoId, frameIdx in zip( + predicted_points["videoId"], predicted_points["frameIdx"] + ) + } + ) + frame_keys.sort() + labels = [] + for videoId, frameIdx in frame_keys: + label = LabeledFrame( + video=video_objects[videoId], + frame_idx=frameIdx, + instances=get_frame_predicted_instances(videoId, frameIdx), + ) + labels.append(label) + + return labels + + +def load_labels_json_old( + data_path: str, + parsed_json: dict = None, + adjust_matlab_indexing: bool = True, + fix_rel_paths: bool = True, +) -> List[LabeledFrame]: + """ + Load predicted instances from Talmo's old JSON format. + + Args: + data_path: The path to the JSON file. + parsed_json: The parsed json if already loaded, so we can save + some time if already parsed. + adjust_matlab_indexing: Whether to adjust indexing from MATLAB. + fix_rel_paths: Whether to fix paths to videos to absolute paths. + + Returns: + A newly constructed Labels object. + """ + if parsed_json is None: + data = json_loads(open(data_path).read()) + else: + data = parsed_json + + videos = pd.DataFrame(data["videos"]) + instances = pd.DataFrame(data["instances"]) + points = pd.DataFrame(data["points"]) + predicted_instances = pd.DataFrame(data["predicted_instances"]) + predicted_points = pd.DataFrame(data["predicted_points"]) + + if adjust_matlab_indexing: + instances.frameIdx -= 1 + points.frameIdx -= 1 + predicted_instances.frameIdx -= 1 + predicted_points.frameIdx -= 1 + + points.node -= 1 + predicted_points.node -= 1 + + points.x -= 1 + predicted_points.x -= 1 + + points.y -= 1 + predicted_points.y -= 1 + + skeleton = Skeleton() + skeleton.add_nodes(data["skeleton"]["nodeNames"]) + edges = data["skeleton"]["edges"] + if adjust_matlab_indexing: + edges = np.array(edges) - 1 + for (src_idx, dst_idx) in edges: + skeleton.add_edge( + data["skeleton"]["nodeNames"][src_idx], + data["skeleton"]["nodeNames"][dst_idx], + ) + + if fix_rel_paths: + for i, row in videos.iterrows(): + p = row.filepath + if not os.path.exists(p): + p = os.path.join(os.path.dirname(data_path), p) + if os.path.exists(p): + videos.at[i, "filepath"] = p + + # Make the video objects + video_objects = {} + for i, row in videos.iterrows(): + if videos.at[i, "format"] == "media": + vid = Video.from_media(videos.at[i, "filepath"]) + else: + vid = Video.from_hdf5( + filename=videos.at[i, "filepath"], dataset=videos.at[i, "dataset"] + ) + + video_objects[videos.at[i, "id"]] = vid + + # A function to get all the instances for a particular video frame + def get_frame_instances(video_id, frame_idx): + """ """ + is_in_frame = (points["videoId"] == video_id) & ( + points["frameIdx"] == frame_idx + ) if not is_in_frame.any(): return [] @@ -92,28 +260,33 @@ def get_frame_predicted_instances(video_id, frame_idx): frame_instance_ids = np.unique(points["instanceId"][is_in_frame]) for i, instance_id in enumerate(frame_instance_ids): is_instance = is_in_frame & (points["instanceId"] == instance_id) - track_id = predicted_instances.loc[predicted_instances['id'] == instance_id]['trackId'].values[0] - match_score = predicted_instances.loc[predicted_instances['id'] == instance_id]['matching_score'].values[0] - track_score = predicted_instances.loc[predicted_instances['id'] == instance_id]['tracking_score'].values[0] - instance_points = {data["skeleton"]["nodeNames"][n]: PredictedPoint(x, y, visible=v, score=confidence) - for x, y, n, v, confidence in - zip(*[points[k][is_instance] for k in ["x", "y", "node", "visible", "confidence"]])} - - instance = PredictedInstance(skeleton=skeleton, - points=instance_points, - track=tracks[track_id], - score=match_score) + instance_points = { + data["skeleton"]["nodeNames"][n]: Point(x, y, visible=v) + for x, y, n, v in zip( + *[points[k][is_instance] for k in ["x", "y", "node", "visible"]] + ) + } + + instance = Instance(skeleton=skeleton, points=instance_points) instances.append(instance) return instances # Get the unique labeled frames and construct a list of LabeledFrame objects for them. - frame_keys = list({(videoId, frameIdx) for videoId, frameIdx in zip(predicted_points["videoId"], predicted_points["frameIdx"])}) + frame_keys = list( + { + (videoId, frameIdx) + for videoId, frameIdx in zip(points["videoId"], points["frameIdx"]) + } + ) frame_keys.sort() labels = [] for videoId, frameIdx in frame_keys: - label = LabeledFrame(video=video_objects[videoId], frame_idx=frameIdx, - instances = get_frame_predicted_instances(videoId, frameIdx)) + label = LabeledFrame( + video=video_objects[videoId], + frame_idx=frameIdx, + instances=get_frame_instances(videoId, frameIdx), + ) labels.append(label) - return Labels(labels) \ No newline at end of file + return labels diff --git a/sleap/io/video.py b/sleap/io/video.py index 5caad1294..556e921de 100644 --- a/sleap/io/video.py +++ b/sleap/io/video.py @@ -9,22 +9,26 @@ import numpy as np import attr import cattr +import logging -from typing import Iterable, Union, List +from typing import Iterable, Union, List, Tuple + +logger = logging.getLogger(__name__) @attr.s(auto_attribs=True, cmp=False) class HDF5Video: """ - Video data stored as 4D datasets in HDF5 files can be imported into - the sLEAP system with this class. + Video data stored as 4D datasets in HDF5 files. Args: - filename: The name of the HDF5 file where the dataset with video data is stored. + filename: The name of the HDF5 file where the dataset with video data + is stored. dataset: The name of the HDF5 dataset where the video data is stored. file_h5: The h5.File object that the underlying dataset is stored. dataset_h5: The h5.Dataset object that the underlying data is stored. - input_format: A string value equal to either "channels_last" or "channels_first". + input_format: A string value equal to either "channels_last" or + "channels_first". This specifies whether the underlying video data is stored as: * "channels_first": shape = (frames, channels, width, height) @@ -38,6 +42,9 @@ class HDF5Video: convert_range: bool = attr.ib(default=True) def __attrs_post_init__(self): + """Called by attrs after __init__().""" + + self.__original_to_current_frame_idx = dict() # Handle cases where the user feeds in h5.File objects instead of filename if isinstance(self.filename, h5.File): @@ -45,9 +52,11 @@ def __attrs_post_init__(self): self.filename = self.__file_h5.filename elif type(self.filename) is str: try: - self.__file_h5 = h5.File(self.filename, 'r') + self.__file_h5 = h5.File(self.filename, "r") except OSError as ex: - raise FileNotFoundError(f"Could not find HDF5 file {self.filename}") from ex + raise FileNotFoundError( + f"Could not find HDF5 file {self.filename}" + ) from ex else: self.__file_h5 = None @@ -56,14 +65,27 @@ def __attrs_post_init__(self): self.__dataset_h5 = self.dataset self.__file_h5 = self.__dataset_h5.file self.dataset = self.__dataset_h5.name - elif self.dataset is not None and type(self.dataset) is str: + + # File loaded and dataset name given, so load dataset + elif isinstance(self.dataset, str) and (self.__file_h5 is not None): self.__dataset_h5 = self.__file_h5[self.dataset] + + # Check for frame_numbers dataset corresponding to video + base_dataset_path = "/".join(self.dataset.split("/")[:-1]) + framenum_dataset = f"{base_dataset_path}/frame_numbers" + if framenum_dataset in self.__file_h5: + original_idx_lists = self.__file_h5[framenum_dataset] + # Create map from idx in original video to idx in current + for current_idx in range(len(original_idx_lists)): + original_idx = original_idx_lists[current_idx] + self.__original_to_current_frame_idx[original_idx] = current_idx + else: self.__dataset_h5 = None - @input_format.validator def check(self, attribute, value): + """Called by attrs to validates input format.""" if value not in ["channels_first", "channels_last"]: raise ValueError(f"HDF5Video input_format={value} invalid.") @@ -76,45 +98,85 @@ def check(self, attribute, value): self.__width_idx = 2 self.__height_idx = 1 - def matches(self, other): + def matches(self, other: "HDF5Video") -> bool: """ - Check if attributes match. + Check if attributes match those of another video. Args: - other: The instance to compare with. + other: The other video to compare with. Returns: - True if attributes match, False otherwise + True if attributes match, False otherwise. """ - return self.filename == other.filename and \ - self.dataset == other.dataset and \ - self.convert_range == other.convert_range and \ - self.input_format == other.input_format + return ( + self.filename == other.filename + and self.dataset == other.dataset + and self.convert_range == other.convert_range + and self.input_format == other.input_format + ) + + def close(self): + """Closes the HDF5 file object (if it's open).""" + if self.__file_h5: + try: + self.__file_h5.close() + except: + pass + self.__file_h5 = None + + def __del__(self): + """Releases file object.""" + self.close() # The properties and methods below complete our contract with the # higher level Video interface. @property def frames(self): + """See :class:`Video`.""" return self.__dataset_h5.shape[0] @property def channels(self): + """See :class:`Video`.""" + if "channels" in self.__dataset_h5.attrs: + return int(self.__dataset_h5.attrs["channels"]) return self.__dataset_h5.shape[self.__channel_idx] @property def width(self): + """See :class:`Video`.""" + if "width" in self.__dataset_h5.attrs: + return int(self.__dataset_h5.attrs["width"]) return self.__dataset_h5.shape[self.__width_idx] @property def height(self): + """See :class:`Video`.""" + if "height" in self.__dataset_h5.attrs: + return int(self.__dataset_h5.attrs["height"]) return self.__dataset_h5.shape[self.__height_idx] @property def dtype(self): + """See :class:`Video`.""" return self.__dataset_h5.dtype - def get_frame(self, idx):# -> np.ndarray: + @property + def last_frame_idx(self) -> int: + """ + The idx number of the last frame. + + Overrides method of base :class:`Video` class for videos with + select frames indexed by number from original video, since the last + frame index here will not match the number of frames in video. + """ + if self.__original_to_current_frame_idx: + last_key = sorted(self.__original_to_current_frame_idx.keys())[-1] + return last_key + return self.frames - 1 + + def get_frame(self, idx) -> np.ndarray: """ Get a frame from the underlying HDF5 video data. @@ -124,12 +186,26 @@ def get_frame(self, idx):# -> np.ndarray: Returns: The numpy.ndarray representing the video frame data. """ + # If we only saved some frames from a video, map to idx in dataset. + if self.__original_to_current_frame_idx: + if idx in self.__original_to_current_frame_idx: + idx = self.__original_to_current_frame_idx[idx] + else: + raise ValueError(f"Frame index {idx} not in original index.") + frame = self.__dataset_h5[idx] + if self.__dataset_h5.attrs.get("format", ""): + frame = cv2.imdecode(frame, cv2.IMREAD_UNCHANGED) + + # Add dimension for single channel (dropped by opencv). + if frame.ndim == 2: + frame = frame[..., np.newaxis] + if self.input_format == "channels_first": frame = np.transpose(frame, (2, 1, 0)) - if self.convert_range and np.max(frame) <= 1.: + if self.convert_range and np.max(frame) <= 1.0: frame = (frame * 255).astype(int) return frame @@ -138,17 +214,19 @@ def get_frame(self, idx):# -> np.ndarray: @attr.s(auto_attribs=True, cmp=False) class MediaVideo: """ - Video data stored in traditional media formats readable by FFMPEG can be loaded - with this class. This class provides bare minimum read only interface on top of + Video data stored in traditional media formats readable by FFMPEG + + This class provides bare minimum read only interface on top of OpenCV's VideoCapture class. Args: filename: The name of the file (.mp4, .avi, etc) grayscale: Whether the video is grayscale or not. "auto" means detect - based on first frame. + based on first frame. + bgr: Whether color channels ordered as (blue, green, red). """ + filename: str = attr.ib() - # grayscale: bool = attr.ib(default=None, converter=bool) grayscale: bool = attr.ib() bgr: bool = attr.ib(default=True) _detect_grayscale = False @@ -165,7 +243,9 @@ def __reader(self): # Load if not already loaded if self._reader_ is None: if not os.path.isfile(self.filename): - raise FileNotFoundError(f"Could not find filename video filename named {self.filename}") + raise FileNotFoundError( + f"Could not find filename video filename named {self.filename}" + ) # Try and open the file either locally in current directory or with full path self._reader_ = cv2.VideoCapture(self.filename) @@ -173,7 +253,9 @@ def __reader(self): # If the user specified None for grayscale bool, figure it out based on the # the first frame of data. if self._detect_grayscale is True: - self.grayscale = bool(np.alltrue(self.__test_frame[..., 0] == self.__test_frame[..., -1])) + self.grayscale = bool( + np.alltrue(self.__test_frame[..., 0] == self.__test_frame[..., -1]) + ) # Return cached reader return self._reader_ @@ -188,23 +270,25 @@ def __test_frame(self): # Return stored test frame return self._test_frame_ - def matches(self, other): + def matches(self, other: "MediaVideo") -> bool: """ - Check if attributes match. + Check if attributes match those of another video. Args: - other: The instance to compare with. + other: The other video to compare with. Returns: - True if attributes match, False otherwise + True if attributes match, False otherwise. """ - return self.filename == other.filename and \ - self.grayscale == other.grayscale and \ - self.bgr == other.bgr - + return ( + self.filename == other.filename + and self.grayscale == other.grayscale + and self.bgr == other.bgr + ) @property - def fps(self): + def fps(self) -> float: + """Returns frames per second of video.""" return self.__reader.get(cv2.CAP_PROP_FPS) # The properties and methods below complete our contract with the @@ -212,14 +296,17 @@ def fps(self): @property def frames(self): + """See :class:`Video`.""" return int(self.__reader.get(cv2.CAP_PROP_FRAME_COUNT)) @property def frames_float(self): + """See :class:`Video`.""" return self.__reader.get(cv2.CAP_PROP_FRAME_COUNT) @property def channels(self): + """See :class:`Video`.""" if self.grayscale: return 1 else: @@ -227,17 +314,21 @@ def channels(self): @property def width(self): + """See :class:`Video`.""" return self.__test_frame.shape[1] @property def height(self): + """See :class:`Video`.""" return self.__test_frame.shape[0] @property def dtype(self): + """See :class:`Video`.""" return self.__test_frame.dtype - def get_frame(self, idx, grayscale=None): + def get_frame(self, idx: int, grayscale: bool = None) -> np.ndarray: + """See :class:`Video`.""" if self.__reader.get(cv2.CAP_PROP_POS_FRAMES) != idx: self.__reader.set(cv2.CAP_PROP_POS_FRAMES, idx) @@ -247,10 +338,10 @@ def get_frame(self, idx, grayscale=None): grayscale = self.grayscale if grayscale: - frame = frame[...,0][...,None] + frame = frame[..., 0][..., None] if self.bgr: - frame = frame[...,::-1] + frame = frame[..., ::-1] return frame @@ -265,6 +356,7 @@ class NumpyVideo: * numpy data shape: (frames, width, height, channels) """ + filename: attr.ib() def __attrs_post_init__(self): @@ -282,74 +374,90 @@ def __attrs_post_init__(self): try: self.__data = np.load(self.filename) except OSError as ex: - raise FileNotFoundError(f"Could not find filename {self.filename}") from ex + raise FileNotFoundError( + f"Could not find filename {self.filename}" + ) from ex else: self.__data = None # The properties and methods below complete our contract with the # higher level Video interface. - def matches(self, other): + def matches(self, other: "NumpyVideo") -> np.ndarray: """ - Check if attributes match. + Check if attributes match those of another video. Args: - other: The instance to comapare with. + other: The other video to compare with. Returns: - True if attributes match, False otherwise + True if attributes match, False otherwise. """ return np.all(self.__data == other.__data) @property def frames(self): + """See :class:`Video`.""" return self.__data.shape[self.__frame_idx] @property def channels(self): + """See :class:`Video`.""" return self.__data.shape[self.__channel_idx] @property def width(self): + """See :class:`Video`.""" return self.__data.shape[self.__width_idx] @property def height(self): + """See :class:`Video`.""" return self.__data.shape[self.__height_idx] @property def dtype(self): + """See :class:`Video`.""" return self.__data.dtype def get_frame(self, idx): + """See :class:`Video`.""" return self.__data[idx] @attr.s(auto_attribs=True, cmp=False) class ImgStoreVideo: """ - Video data stored as an ImgStore dataset. See: https://github.com/loopbio/imgstore - This class is just a lightweight wrapper for reading such datasets as videos sources - for sLEAP. + Video data stored as an ImgStore dataset. + + See: https://github.com/loopbio/imgstore + This class is just a lightweight wrapper for reading such datasets as + video sources for SLEAP. Args: filename: The name of the file or directory to the imgstore. - index_by_original: ImgStores are great for storing a collection of frame - selected frames from an larger video. If the index_by_original is set to - True than the get_frame function will accept the original frame numbers of - from original video. If False, then it will accept the frame index from the - store directly. + index_by_original: ImgStores are great for storing a collection of + selected frames from an larger video. If the index_by_original is + set to True then the get_frame function will accept the original + frame numbers of from original video. If False, then it will + accept the frame index from the store directly. + Default to True so that we can use an ImgStoreVideo in a dataset + to replace another video without having to update all the frame + indices on :class:`LabeledFrame` objects in the dataset. """ filename: str = attr.ib(default=None) index_by_original: bool = attr.ib(default=True) + _store_ = None + _img_ = None def __attrs_post_init__(self): # If the filename does not contain metadata.yaml, append it to the filename # assuming that this is a directory that contains the imgstore. - if 'metadata.yaml' not in self.filename: - self.filename = os.path.join(self.filename, 'metadata.yaml') + if "metadata.yaml" not in self.filename: + # Use "/" since this works on Windows and posix + self.filename = self.filename + "/metadata.yaml" # Make relative path into absolute, ImgStores don't work properly it seems # without full paths if we change working directories. Video.fixup_path will @@ -357,7 +465,6 @@ def __attrs_post_init__(self): self.filename = os.path.abspath(self.filename) self.__store = None - self.open() # The properties and methods below complete our contract with the # higher level Video interface. @@ -372,14 +479,35 @@ def matches(self, other): Returns: True if attributes match, False otherwise """ - return self.filename == other.filename and self.index_by_original == other.index_by_original + return ( + self.filename == other.filename + and self.index_by_original == other.index_by_original + ) + + @property + def __store(self): + if self._store_ is None: + self.open() + return self._store_ + + @__store.setter + def __store(self, val): + self._store_ = val + + @property + def __img(self): + if self._img_ is None: + self.open() + return self._img_ @property def frames(self): + """See :class:`Video`.""" return self.__store.frame_count @property def channels(self): + """See :class:`Video`.""" if len(self.__img.shape) < 3: return 1 else: @@ -387,38 +515,57 @@ def channels(self): @property def width(self): + """See :class:`Video`.""" return self.__img.shape[1] @property def height(self): + """See :class:`Video`.""" return self.__img.shape[0] @property def dtype(self): + """See :class:`Video`.""" return self.__img.dtype - def get_frame(self, frame_number) -> np.ndarray: + @property + def last_frame_idx(self) -> int: + """ + The idx number of the last frame. + + Overrides method of base :class:`Video` class for videos with + select frames indexed by number from original video, since the last + frame index here will not match the number of frames in video. + """ + if self.index_by_original: + return self.__store.frame_max + return self.frames - 1 + + def get_frame(self, frame_number: int) -> np.ndarray: """ Get a frame from the underlying ImgStore video data. Args: - frame_num: The number of the frame to get. If index_by_original is set to True, - then this number should actually be a frame index withing the imgstore. That is, - if there are 4 frames in the imgstore, this number shoulde be from 0 to 3. + frame_number: The number of the frame to get. If + index_by_original is set to True, then this number should + actually be a frame index within the imgstore. That is, + if there are 4 frames in the imgstore, this number should be + be from 0 to 3. Returns: The numpy.ndarray representing the video frame data. """ # Check if we need to open the imgstore and do it if needed - if not self.imgstore: + if not self._store_: self.open() if self.index_by_original: img, (frame_number, frame_timestamp) = self.__store.get_image(frame_number) else: - img, (frame_number, frame_timestamp) = self.__store.get_image(frame_number=None, - frame_index=frame_number) + img, (frame_number, frame_timestamp) = self.__store.get_image( + frame_number=None, frame_index=frame_number + ) # If the frame has one channel, add a singleton channel as it seems other # video implementations do this. @@ -444,12 +591,12 @@ def open(self): Returns: None """ - if not self.imgstore: + if not self._store_: # Open the imgstore - self.__store = imgstore.new_for_filename(self.filename) + self._store_ = imgstore.new_for_filename(self.filename) # Read a frame so we can compute shape an such - self.__img, (frame_number, frame_timestamp) = self.__store.get_next_image() + self._img_, (frame_number, frame_timestamp) = self._store_.get_next_image() def close(self): """ @@ -467,37 +614,40 @@ def close(self): @attr.s(auto_attribs=True, cmp=False) class Video: """ - The top-level interface to any Video data used by sLEAP is represented by - the :class:`.Video` class. This class provides a common interface for - various supported video data backends. It provides the bare minimum of - properties and methods that any video data needs to support in order to - function with other sLEAP components. This interface currently only supports - reading of video data, there is no write support. Unless one is creating a new video + The top-level interface to any Video data used by SLEAP. + + This class provides a common interface for various supported video data + backends. It provides the bare minimum of properties and methods that + any video data needs to support in order to function with other SLEAP + components. This interface currently only supports reading of video + data, there is no write support. Unless one is creating a new video backend, this class should be instantiated from its various class methods for different formats. For example: - >>> video = Video.from_hdf5(filename='test.h5', dataset='box') - >>> video = Video.from_media(filename='test.mp4') + >>> video = Video.from_hdf5(filename="test.h5", dataset="box") + >>> video = Video.from_media(filename="test.mp4") Or we can use auto-detection based on filename: - >>> video = Video.from_filename(filename='test.mp4') + >>> video = Video.from_filename(filename="test.mp4") Args: - backend: A backend is and object that implements the following basic - required methods and properties + backend: A backend is an object that implements the following basic + required methods and properties * Properties * :code:`frames`: The number of frames in the video - * :code:`channels`: The number of channels in the video (e.g. 1 for grayscale, 3 for RGB) + * :code:`channels`: The number of channels in the video + (e.g. 1 for grayscale, 3 for RGB) * :code:`width`: The width of each frame in pixels * :code:`height`: The height of each frame in pixels * Methods - * :code:`get_frame(frame_index: int) -> np.ndarray(shape=(width, height, channels)`: - Get a single frame from the underlying video data + * :code:`get_frame(frame_index: int) -> np.ndarray`: + Get a single frame from the underlying video data with + output shape=(width, height, channels). """ @@ -509,11 +659,23 @@ def __getattr__(self, item): @property def num_frames(self) -> int: - """The number of frames in the video. Just an alias for frames property.""" + """ + The number of frames in the video. Just an alias for frames property. + """ return self.frames @property - def shape(self): + def last_frame_idx(self) -> int: + """ + The idx number of the last frame. Usually `numframes - 1`. + """ + if hasattr(self.backend, "last_frame_idx"): + return self.backend.last_frame_idx + return self.frames - 1 + + @property + def shape(self) -> Tuple[int, int, int, int]: + """ Returns (frame count, height, width, channels).""" return (self.frames, self.height, self.width, self.channels) def __str__(self): @@ -549,10 +711,11 @@ def get_frames(self, idxs: Union[int, Iterable[int]]) -> np.ndarray: idxs: An iterable object that contains the indices of frames. Returns: - The requested video frames with shape (len(idxs), width, height, channels) + The requested video frames with shape + (len(idxs), width, height, channels) """ if np.isscalar(idxs): - idxs = [idxs,] + idxs = [idxs] return np.stack([self.get_frame(idx) for idx in idxs], axis=0) def __getitem__(self, idxs): @@ -562,19 +725,24 @@ def __getitem__(self, idxs): return self.get_frames(idxs) @classmethod - def from_hdf5(cls, dataset: Union[str, h5.Dataset], - filename: Union[str, h5.File] = None, - input_format: str = "channels_last", - convert_range: bool = True): + def from_hdf5( + cls, + dataset: Union[str, h5.Dataset], + filename: Union[str, h5.File] = None, + input_format: str = "channels_last", + convert_range: bool = True, + ) -> "Video": """ - Create an instance of a video object from an HDF5 file and dataset. This - is a helper method that invokes the HDF5Video backend. + Create an instance of a video object from an HDF5 file and dataset. + + This is a helper method that invokes the HDF5Video backend. Args: - dataset: The name of the dataset or and h5.Dataset object. If filename is - h5.File, dataset must be a str of the dataset name. + dataset: The name of the dataset or and h5.Dataset object. If + filename is h5.File, dataset must be a str of the dataset name. filename: The name of the HDF5 file or and open h5.File object. - input_format: Whether the data is oriented with "channels_first" or "channels_last" + input_format: Whether the data is oriented with "channels_first" + or "channels_last" convert_range: Whether we should convert data to [0, 255]-range Returns: @@ -582,20 +750,22 @@ def from_hdf5(cls, dataset: Union[str, h5.Dataset], """ filename = Video.fixup_path(filename) backend = HDF5Video( - filename=filename, - dataset=dataset, - input_format=input_format, - convert_range=convert_range - ) + filename=filename, + dataset=dataset, + input_format=input_format, + convert_range=convert_range, + ) return cls(backend=backend) @classmethod - def from_numpy(cls, filename, *args, **kwargs): + def from_numpy(cls, filename: str, *args, **kwargs) -> "Video": """ Create an instance of a video object from a numpy array. Args: filename: The numpy array or the name of the file + args: Arguments to pass to :class:`NumpyVideo` + kwargs: Arguments to pass to :class:`NumpyVideo` Returns: A Video object with a NumpyVideo backend @@ -605,12 +775,16 @@ def from_numpy(cls, filename, *args, **kwargs): return cls(backend=backend) @classmethod - def from_media(cls, filename: str, *args, **kwargs): + def from_media(cls, filename: str, *args, **kwargs) -> "Video": """ - Create an instance of a video object from a typical media file (e.g. .mp4, .avi). + Create an instance of a video object from a typical media file. + + For example, mp4, avi, or other types readable by FFMPEG. Args: filename: The name of the file + args: Arguments to pass to :class:`MediaVideo` + kwargs: Arguments to pass to :class:`MediaVideo` Returns: A Video object with a MediaVideo backend @@ -620,20 +794,25 @@ def from_media(cls, filename: str, *args, **kwargs): return cls(backend=backend) @classmethod - def from_filename(cls, filename: str, *args, **kwargs): + def from_filename(cls, filename: str, *args, **kwargs) -> "Video": """ - Create an instance of a video object from a filename, auto-detecting the backend. + Create an instance of a video object, auto-detecting the backend. Args: - filename: The path to the video filename. Currently supported types are: + filename: The path to the video filename. + Currently supported types are: + + * Media Videos - AVI, MP4, etc. handled by OpenCV directly + * HDF5 Datasets - .h5 files + * Numpy Arrays - npy files + * imgstore datasets - produced by loopbio's Motif recording + system. See: https://github.com/loopbio/imgstore. - * Media Videos - AVI, MP4, etc. handled by OpenCV directly - * HDF5 Datasets - .h5 files - * Numpy Arrays - npy files - * imgstore datasets - produced by loopbio's Motif recording system. See: https://github.com/loopbio/imgstore. + args: Arguments to pass to :class:`NumpyVideo` + kwargs: Arguments to pass to :class:`NumpyVideo` Returns: - A Video object with the detected backend + A Video object with the detected backend. """ filename = Video.fixup_path(filename) @@ -650,27 +829,29 @@ def from_filename(cls, filename: str, *args, **kwargs): raise ValueError("Could not detect backend for specified filename.") @classmethod - def imgstore_from_filenames(cls, filenames: list, output_filename: str, *args, **kwargs): - """Create an imagestore from a list of image files. + def imgstore_from_filenames( + cls, filenames: list, output_filename: str, *args, **kwargs + ) -> "Video": + """Create an imgstore from a list of image files. Args: filenames: List of filenames for the image files. - output_filename: Filename for the imagestore to create. + output_filename: Filename for the imgstore to create. Returns: - A `Video` object for the new imagestore. + A `Video` object for the new imgstore. """ # get the image size from the first file first_img = cv2.imread(filenames[0], flags=cv2.IMREAD_COLOR) img_shape = first_img.shape - # create the imagestore - store = imgstore.new_for_format('png', - mode='w', basedir=output_filename, - imgshape=img_shape) + # create the imgstore + store = imgstore.new_for_format( + "png", mode="w", basedir=output_filename, imgshape=img_shape + ) - # read each frame and write it to the imagestore + # read each frame and write it to the imgstore # unfortunately imgstore doesn't let us just add the file for i, img_filename in enumerate(filenames): img = cv2.imread(img_filename, flags=cv2.IMREAD_COLOR) @@ -681,30 +862,33 @@ def imgstore_from_filenames(cls, filenames: list, output_filename: str, *args, * # Return an ImgStoreVideo object referencing this new imgstore. return cls(backend=ImgStoreVideo(filename=output_filename)) - @classmethod - def to_numpy(cls, frame_data: np.array, file_name: str): - np.save(file_name, frame_data, 'w') - - def to_imgstore(self, path, - frame_numbers: List[int] = None, - format: str = "png", - index_by_original: bool = True): + def to_imgstore( + self, + path: str, + frame_numbers: List[int] = None, + format: str = "png", + index_by_original: bool = True, + ) -> "Video": """ - Read frames from an arbitrary video backend and store them in a loopbio imgstore. + Converts frames from arbitrary video backend to ImgStoreVideo. + This should facilitate conversion of any video to a loopbio imgstore. Args: path: Filename or directory name to store imgstore. - frame_numbers: A list of frame numbers from the video to save. If None save - the entire video. - format: By default it will create a DirectoryImgStore with lossless PNG format. - Unless the frame_indices = None, in which case, it will default to 'mjpeg/avi' - format for video. + frame_numbers: A list of frame numbers from the video to save. + If None save the entire video. + format: By default it will create a DirectoryImgStore with lossless + PNG format unless the frame_indices = None, in which case, + it will default to 'mjpeg/avi' format for video. index_by_original: ImgStores are great for storing a collection of - selected frames from an larger video. If the index_by_original is set to - True than the get_frame function will accept the original frame numbers of - from original video. If False, then it will accept the frame index from the - store directly. + selected frames from an larger video. If the index_by_original + is set to True then the get_frame function will accept the + original frame numbers of from original video. If False, + then it will accept the frame index from the store directly. + Default to True so that we can use an ImgStoreVideo in a + dataset to replace another video without having to update + all the frame indices on :class:`LabeledFrame` objects in the dataset. Returns: A new Video object that references the imgstore. @@ -730,28 +914,123 @@ def to_imgstore(self, path, # new_backend = self.backend.copy_to(path) # return self.__class__(backend=new_backend) - store = imgstore.new_for_format(format, - mode='w', basedir=path, - imgshape=(self.shape[1], self.shape[2], self.shape[3]), - chunksize=1000) + store = imgstore.new_for_format( + format, + mode="w", + basedir=path, + imgshape=(self.height, self.width, self.channels), + chunksize=1000, + ) # Write the JSON for the original video object to the metadata # of the imgstore for posterity store.add_extra_data(source_sleap_video_obj=Video.cattr().unstructure(self)) import time + for frame_num in frame_numbers: store.add_image(self.get_frame(frame_num), frame_num, time.time()) + # If there are no frames to save for this video, add a dummy frame + # since we can't save an empty imgstore. + if len(frame_numbers) == 0: + store.add_image( + np.zeros((self.height, self.width, self.channels)), 0, time.time() + ) + store.close() # Return an ImgStoreVideo object referencing this new imgstore. - return self.__class__(backend=ImgStoreVideo(filename=path, index_by_original=index_by_original)) + return self.__class__( + backend=ImgStoreVideo(filename=path, index_by_original=index_by_original) + ) + + def to_hdf5( + self, + path: str, + dataset: str, + frame_numbers: List[int] = None, + format: str = "", + index_by_original: bool = True, + ): + """ + Converts frames from arbitrary video backend to HDF5Video. + + Used for building an HDF5 that holds all data needed for training. + + Args: + path: Filename to HDF5 (which could already exist). + dataset: The HDF5 dataset in which to store video frames. + frame_numbers: A list of frame numbers from the video to save. + If None save the entire video. + format: If non-empty, then encode images in format before saving. + Otherwise, save numpy matrix of frames. + index_by_original: If the index_by_original is set to True then + the get_frame function will accept the original frame + numbers of from original video. + If False, then it will accept the frame index directly. + Default to True so that we can use resulting video in a + dataset to replace another video without having to update + all the frame indices in the dataset. + + Returns: + A new Video object that references the HDF5 dataset. + """ + + # If the user has not provided a list of frames to store, store them all. + if frame_numbers is None: + frame_numbers = range(self.num_frames) + + if frame_numbers: + frame_data = self.get_frames(frame_numbers) + else: + frame_data = np.zeros((1, 1, 1, 1)) + + frame_numbers_data = np.array(list(frame_numbers), dtype=int) + + with h5.File(path, "a") as f: + + if format: + + def encode(img): + _, encoded = cv2.imencode("." + format, img) + return np.squeeze(encoded) + + dtype = h5.special_dtype(vlen=np.dtype("int8")) + dset = f.create_dataset( + dataset + "/video", (len(frame_numbers),), dtype=dtype + ) + dset.attrs["format"] = format + dset.attrs["channels"] = self.channels + dset.attrs["height"] = self.height + dset.attrs["width"] = self.width + + for i in range(len(frame_numbers)): + dset[i] = encode(frame_data[i]) + else: + f.create_dataset( + dataset + "/video", + data=frame_data, + compression="gzip", + compression_opts=9, + ) + + if index_by_original: + f.create_dataset(dataset + "/frame_numbers", data=frame_numbers_data) + + return self.__class__( + backend=HDF5Video( + filename=path, + dataset=dataset + "/video", + input_format="channels_last", + convert_range=False, + ) + ) @staticmethod def cattr(): """ - Return a cattr converter for serialiazing\deseriializing Video objects. + Returns a cattr converter for serialiazing/deserializing Video objects. Returns: A cattr converter. @@ -760,10 +1039,10 @@ def cattr(): # When we are structuring video backends, try to fixup the video file paths # in case they are coming from a different computer or the file has been moved. def fixup_video(x, cl): - if 'filename' in x: - x['filename'] = Video.fixup_path(x['filename']) - if 'file' in x: - x['file'] = Video.fixup_path(x['file']) + if "filename" in x: + x["filename"] = Video.fixup_path(x["filename"]) + if "file" in x: + x["file"] = Video.fixup_path(x["file"]) return cl(**x) @@ -777,17 +1056,29 @@ def fixup_video(x, cl): return vid_cattr @staticmethod - def fixup_path(path, raise_error=False) -> str: + def fixup_path(path: str, raise_error: bool = False) -> str: """ - Given a path to a video try to find it. This is attempt to make the paths - serialized for different video objects portabls across multiple computers. - The default behaviour is to store whatever path is stored on the backend - object. If this is an absolute path it is almost certainly wrong when - transfered when the object is created on another computer. We try to - find the video by looking in the current working directory as well. + Tries to locate video if the given path doesn't work. + + Given a path to a video try to find it. This is attempt to make the + paths serialized for different video objects portable across multiple + computers. The default behavior is to store whatever path is stored + on the backend object. If this is an absolute path it is almost + certainly wrong when transferred when the object is created on + another computer. We try to find the video by looking in the current + working directory as well. + + Note that when loading videos during the process of deserializing a + saved :class:`Labels` dataset, it's usually preferable to fix video + paths using a `video_callback`. Args: path: The path the video asset. + raise_error: Whether to raise error if we cannot find video. + + Raises: + FileNotFoundError: If file still cannot be found and raise_error + is True. Returns: The fixed up path @@ -808,7 +1099,7 @@ def fixup_path(path, raise_error=False) -> str: # Special case: this is an ImgStore path! We cant use # basename because it will strip the directory name off - elif path.endswith('metadata.yaml'): + elif path.endswith("metadata.yaml"): # Get the parent dir of the YAML file. img_store_dir = os.path.basename(os.path.split(path)[0]) @@ -819,6 +1110,5 @@ def fixup_path(path, raise_error=False) -> str: if raise_error: raise FileNotFoundError(f"Cannot find a video file: {path}") else: - print(f"Cannot find a video file: {path}") + logger.warning(f"Cannot find a video file: {path}") return path - diff --git a/sleap/io/visuals.py b/sleap/io/visuals.py index 1949ed58a..a60e4bc28 100644 --- a/sleap/io/visuals.py +++ b/sleap/io/visuals.py @@ -1,3 +1,7 @@ +""" +Module for generating videos with visual annotation overlays. +""" + from sleap.io.video import Video from sleap.io.dataset import Labels from sleap.util import usable_cpu_count @@ -7,32 +11,37 @@ import numpy as np import math from time import time, clock -from typing import List +from typing import List, Tuple from queue import Queue from threading import Thread import logging + logger = logging.getLogger(__name__) # Object that signals shutdown _sentinel = object() + def reader(out_q: Queue, video: Video, frames: List[int]): """Read frame images from video and send them into queue. Args: out_q: Queue to send (list of frame indexes, ndarray of frame images) for chunks of video. - video: the `Video` object to read - frames: full list frame indexes we want to read + video: The `Video` object to read. + frames: Full list frame indexes we want to read. + + Returns: + None. """ cv2.setNumThreads(usable_cpu_count()) total_count = len(frames) chunk_size = 64 - chunk_count = math.ceil(total_count/chunk_size) + chunk_count = math.ceil(total_count / chunk_size) logger.info(f"Chunks: {chunk_count}, chunk size: {chunk_size}") @@ -50,7 +59,7 @@ def reader(out_q: Queue, video: Video, frames: List[int]): video_frame_images = video[frames_idx_chunk] elapsed = clock() - t0 - fps = len(frames_idx_chunk)/elapsed + fps = len(frames_idx_chunk) / elapsed logger.debug(f"reading chunk {i} in {elapsed} s = {fps} fps") i += 1 @@ -59,14 +68,16 @@ def reader(out_q: Queue, video: Video, frames: List[int]): # send _sentinal object into queue to signal that we're done out_q.put(_sentinel) + def marker(in_q: Queue, out_q: Queue, labels: Labels, video_idx: int): """Annotate frame images (draw instances). Args: - in_q: Queue with (list of frame indexes, ndarray of frame images) - out_q: Queue to send annotated images as (images, h, w, channels) ndarray + in_q: Queue with (list of frame indexes, ndarray of frame images). + out_q: Queue to send annotated images as + (images, h, w, channels) ndarray. labels: the `Labels` object from which to get data for annotating. - video_idx: index of `Video` in `labels.videos` list + video_idx: index of `Video` in `labels.videos` list. Returns: None. @@ -89,14 +100,15 @@ def marker(in_q: Queue, out_q: Queue, labels: Labels, video_idx: int): imgs = [] for i, frame_idx in enumerate(frames_idx_chunk): img = get_frame_image( - video_frame=video_frame_images[i], - video_idx=video_idx, - frame_idx=frame_idx, - labels=labels) + video_frame=video_frame_images[i], + video_idx=video_idx, + frame_idx=frame_idx, + labels=labels, + ) imgs.append(img) elapsed = clock() - t0 - fps = len(imgs)/elapsed + fps = len(imgs) / elapsed logger.debug(f"drawing chunk {chunk_i} in {elapsed} s = {fps} fps") chunk_i += 1 out_q.put(imgs) @@ -104,8 +116,14 @@ def marker(in_q: Queue, out_q: Queue, labels: Labels, video_idx: int): # send _sentinal object into queue to signal that we're done out_q.put(_sentinel) -def writer(in_q: Queue, progress_queue: Queue, - filename: str, fps: int, img_w_h: tuple): + +def writer( + in_q: Queue, + progress_queue: Queue, + filename: str, + fps: float, + img_w_h: Tuple[int, int], +): """Write annotated images to video. Args: @@ -123,7 +141,7 @@ def writer(in_q: Queue, progress_queue: Queue, cv2.setNumThreads(usable_cpu_count()) - fourcc = cv2.VideoWriter_fourcc(*'MJPG') + fourcc = cv2.VideoWriter_fourcc(*"MJPG") out = cv2.VideoWriter(filename, fourcc, fps, img_w_h) start_time = clock() @@ -143,7 +161,7 @@ def writer(in_q: Queue, progress_queue: Queue, out.write(img) elapsed = clock() - t0 - fps = len(data)/elapsed + fps = len(data) / elapsed logger.debug(f"writing chunk {i} in {elapsed} s = {fps} fps") i += 1 @@ -155,14 +173,28 @@ def writer(in_q: Queue, progress_queue: Queue, # send (-1, time) to signal done progress_queue.put((-1, total_elapsed)) + def save_labeled_video( - filename: str, - labels: Labels, - video: Video, - frames: List[int], - fps: int=15, - gui_progress: bool=False): - """Function to generate and save video with annotations.""" + filename: str, + labels: Labels, + video: Video, + frames: List[int], + fps: int = 15, + gui_progress: bool = False, +): + """Function to generate and save video with annotations. + + Args: + filename: Output filename. + labels: The dataset from which to get data. + video: The source :class:`Video` we want to annotate. + frames: List of frames to include in output video. + fps: Frames per second for output video. + gui_progress: Whether to show Qt GUI progress dialog. + + Returns: + None. + """ output_size = (video.height, video.width) print(f"Writing video with {len(frames)} frame images...") @@ -173,12 +205,14 @@ def save_labeled_video( q2 = Queue() progress_queue = Queue() - thread_read = Thread(target=reader, args=(q1, video, frames,)) - thread_mark = Thread(target=marker, args=(q1, q2, labels, labels.videos.index(video))) - thread_write = Thread(target=writer, args=( - q2, progress_queue, filename, - fps, (video.width, video.height), - )) + thread_read = Thread(target=reader, args=(q1, video, frames)) + thread_mark = Thread( + target=marker, args=(q1, q2, labels, labels.videos.index(video)) + ) + thread_write = Thread( + target=writer, + args=(q2, progress_queue, filename, fps, (video.width, video.height)), + ) thread_read.start() thread_mark.start() @@ -189,9 +223,8 @@ def save_labeled_video( from PySide2 import QtWidgets, QtCore progress_win = QtWidgets.QProgressDialog( - f"Generating video with {len(frames)} frames...", - "Cancel", - 0, len(frames)) + f"Generating video with {len(frames)} frames...", "Cancel", 0, len(frames) + ) progress_win.setMinimumWidth(300) progress_win.setWindowModality(QtCore.Qt.WindowModal) @@ -201,20 +234,24 @@ def save_labeled_video( break if progress_win is not None and progress_win.wasCanceled(): break - fps = frames_complete/elapsed + fps = frames_complete / elapsed remaining_frames = len(frames) - frames_complete - remaining_time = remaining_frames/fps + remaining_time = remaining_frames / fps if gui_progress: progress_win.setValue(frames_complete) else: - print(f"Finished {frames_complete} frames in {elapsed} s, fps = {fps}, approx {remaining_time} s remaining") + print( + f"Finished {frames_complete} frames in {elapsed} s, fps = {fps}, approx {remaining_time} s remaining" + ) elapsed = clock() - t0 - fps = len(frames)/elapsed + fps = len(frames) / elapsed print(f"Done in {elapsed} s, fps = {fps}.") -def img_to_cv(img): + +def img_to_cv(img: np.ndarray) -> np.ndarray: + """Prepares frame image as needed for opencv.""" # Convert RGB to BGR for OpenCV if img.shape[-1] == 3: img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) @@ -223,27 +260,58 @@ def img_to_cv(img): img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) return img -def get_frame_image(video_frame, video_idx, frame_idx, labels): + +def get_frame_image( + video_frame: np.ndarray, video_idx: int, frame_idx: int, labels: Labels +) -> np.ndarray: + """Returns single annotated frame image. + + Args: + video_frame: The ndarray of the frame image. + video_idx: Index of video in :attribute:`Labels.videos` list. + frame_idx: Index of frame in video. + labels: The dataset from which to get data. + + Returns: + ndarray of frame image with visual annotations added. + """ img = img_to_cv(video_frame) plot_instances_cv(img, video_idx, frame_idx, labels) return img + def _point_int_tuple(point): + """Returns (x, y) tuple from :class:`Point`.""" return int(point.x), int(point.y) -def plot_instances_cv(img, video_idx, frame_idx, labels): - cmap = ([ - [0, 114, 189], - [217, 83, 25], - [237, 177, 32], - [126, 47, 142], - [119, 172, 48], - [77, 190, 238], - [162, 20, 47], - ]) + +def plot_instances_cv( + img: np.ndarray, video_idx: int, frame_idx: int, labels: Labels +) -> np.ndarray: + """Adds visuals annotations to single frame image. + + Args: + img: The ndarray of the frame image. + video_idx: Index of video in :attribute:`Labels.videos` list. + frame_idx: Index of frame in video. + labels: The dataset from which to get data. + + Returns: + ndarray of frame image with visual annotations added. + """ + cmap = [ + [0, 114, 189], + [217, 83, 25], + [237, 177, 32], + [126, 47, 142], + [119, 172, 48], + [77, 190, 238], + [162, 20, 47], + ] lfs = labels.find(labels.videos[video_idx], frame_idx) - if len(lfs) == 0: return + if len(lfs) == 0: + return count_no_track = 0 for i, instance in enumerate(lfs[0].instances_to_show): @@ -254,11 +322,29 @@ def plot_instances_cv(img, video_idx, frame_idx, labels): track_idx = len(labels.tracks) + count_no_track count_no_track += 1 - inst_color = cmap[track_idx%len(cmap)] + inst_color = cmap[track_idx % len(cmap)] plot_instance_cv(img, instance, inst_color) -def plot_instance_cv(img, instance, color, marker_radius=4): + +def plot_instance_cv( + img: np.ndarray, + instance: "Instance", + color: Tuple[int, int, int], + marker_radius: float = 4, +) -> np.ndarray: + """ + Add visual annotations for single instance. + + Args: + img: The ndarray of the frame image. + instance: The :class:`Instance` to add to frame image. + color: (r, g, b) color for this instance. + marker_radius: Radius of marker for instance points (nodes). + + Returns: + ndarray of frame image with visual annotations for instance added. + """ # RGB -> BGR for cv2 cv_color = color[::-1] @@ -266,23 +352,31 @@ def plot_instance_cv(img, instance, color, marker_radius=4): for (node, point) in instance.nodes_points: # plot node at point if point.visible and not point.isnan(): - cv2.circle(img, - _point_int_tuple(point), - marker_radius, - cv_color, - lineType=cv2.LINE_AA) + cv2.circle( + img, + _point_int_tuple(point), + marker_radius, + cv_color, + lineType=cv2.LINE_AA, + ) for (src, dst) in instance.skeleton.edges: # Make sure that both nodes are present in this instance before drawing edge if src in instance and dst in instance: - if instance[src].visible and instance[dst].visible \ - and not instance[src].isnan() and not instance[dst].isnan(): + if ( + instance[src].visible + and instance[dst].visible + and not instance[src].isnan() + and not instance[dst].isnan() + ): cv2.line( - img, - _point_int_tuple(instance[src]), - _point_int_tuple(instance[dst]), - cv_color, - lineType=cv2.LINE_AA) + img, + _point_int_tuple(instance[src]), + _point_int_tuple(instance[dst]), + cv_color, + lineType=cv2.LINE_AA, + ) + if __name__ == "__main__": @@ -291,13 +385,21 @@ def plot_instance_cv(img, instance, color, marker_radius=4): parser = argparse.ArgumentParser() parser.add_argument("data_path", help="Path to labels json file") - parser.add_argument('-o', '--output', type=str, default=None, - help='The output filename for the video') - parser.add_argument('-f', '--fps', type=int, default=15, - help='Frames per second') - parser.add_argument('--frames', type=frame_list, default="", - help='list of frames to predict. Either comma separated list (e.g. 1,2,3) or ' - 'a range separated by hyphen (e.g. 1-3). (default is entire video)') + parser.add_argument( + "-o", + "--output", + type=str, + default=None, + help="The output filename for the video", + ) + parser.add_argument("-f", "--fps", type=int, default=15, help="Frames per second") + parser.add_argument( + "--frames", + type=frame_list, + default="", + help="list of frames to predict. Either comma separated list (e.g. 1,2,3) or " + "a range separated by hyphen (e.g. 1-3). (default is entire video)", + ) args = parser.parse_args() video_callback = Labels.make_video_callback([os.path.dirname(args.data_path)]) @@ -312,10 +414,12 @@ def plot_instance_cv(img, instance, color, marker_radius=4): filename = args.output or args.data_path + ".avi" - save_labeled_video(filename=filename, - labels=labels, - video=labels.videos[0], - frames=frames, - fps=args.fps) + save_labeled_video( + filename=filename, + labels=labels, + video=labels.videos[0], + frames=frames, + fps=args.fps, + ) print(f"Video saved as: {filename}") diff --git a/sleap/nn/architectures/__init__.py b/sleap/nn/architectures/__init__.py index ba2d97b94..8b653eb77 100644 --- a/sleap/nn/architectures/__init__.py +++ b/sleap/nn/architectures/__init__.py @@ -1,8 +1,14 @@ from sleap.nn.architectures.leap import LeapCNN from sleap.nn.architectures.unet import UNet, StackedUNet from sleap.nn.architectures.hourglass import StackedHourglass +from sleap.nn.architectures.resnet import ResNet50 +from typing import TypeVar # TODO: We can set this up to find all classes under sleap.nn.architectures -available_archs = [LeapCNN, UNet, StackedUNet, StackedHourglass] +available_archs = [LeapCNN, UNet, StackedUNet, StackedHourglass, ResNet50] +available_arch_names = [arch.__name__ for arch in available_archs] +BackboneType = TypeVar("BackboneType", *available_archs) -__all__ = ['available_archs'] + [arch.__name__ for arch in available_archs] +__all__ = ["available_archs", "available_arch_names", "BackboneType"] + [ + arch.__name__ for arch in available_archs +] diff --git a/sleap/nn/architectures/common.py b/sleap/nn/architectures/common.py index 61c747332..fb0804ebd 100644 --- a/sleap/nn/architectures/common.py +++ b/sleap/nn/architectures/common.py @@ -4,6 +4,7 @@ from keras.layers import Conv2D, BatchNormalization, Add + def expand_to_n(x, n): """Expands an object `x` to `n` elements if scalar. @@ -18,15 +19,14 @@ def expand_to_n(x, n): """ if not isinstance(x, (collections.Sequence, np.ndarray)): - x = [x,] - + x = [x] + if np.size(x) == 1: x = np.tile(x, n) elif np.size(x) != n: raise ValueError("Variable to expand must be scalar.") - - return x + return x def conv(num_filters, kernel_size=(3, 3), activation="relu", **kwargs): @@ -41,7 +41,14 @@ def conv(num_filters, kernel_size=(3, 3), activation="relu", **kwargs): Returns: keras.layers.Conv2D instance built with presets """ - return Conv2D(num_filters, kernel_size=kernel_size, activation=activation, padding="same", **kwargs) + return Conv2D( + num_filters, + kernel_size=kernel_size, + activation=activation, + padding="same", + **kwargs + ) + def conv1(num_filters, **kwargs): """Convenience presets for 1x1 Conv2D. @@ -55,6 +62,7 @@ def conv1(num_filters, **kwargs): """ return conv(num_filters, kernel_size=(1, 1), **kwargs) + def conv3(num_filters, **kwargs): """Convenience presets for 3x3 Conv2D. @@ -67,6 +75,7 @@ def conv3(num_filters, **kwargs): """ return conv(num_filters, kernel_size=(3, 3), **kwargs) + def residual_block(x_in, num_filters=None, batch_norm=True): """Residual bottleneck block. @@ -99,26 +108,31 @@ def residual_block(x_in, num_filters=None, batch_norm=True): # Default to output the same number of channels as input if num_filters is None: num_filters = x_in.shape[-1] - + # Number of output channels must be divisible by 2 if num_filters % 2 != 0: - raise ValueError("Number of output filters must be divisible by 2 in residual blocks.") - + raise ValueError( + "Number of output filters must be divisible by 2 in residual blocks." + ) + # If number of input and output channels are different, add a 1x1 conv to use as the # identity tensor to which we add the residual at the end x_identity = x_in if x_in.shape[-1] != num_filters: x_identity = conv1(num_filters)(x_in) - if batch_norm: x_identity = BatchNormalization()(x_identity) - + if batch_norm: + x_identity = BatchNormalization()(x_identity) + # Bottleneck: 1x1 -> 3x3 -> 1x1 -> Add residual to identity x = conv1(num_filters // 2)(x_in) - if batch_norm: x = BatchNormalization()(x) + if batch_norm: + x = BatchNormalization()(x) x = conv3(num_filters // 2)(x) - if batch_norm: x = BatchNormalization()(x) + if batch_norm: + x = BatchNormalization()(x) x = conv1(num_filters)(x) - if batch_norm: x = BatchNormalization()(x) + if batch_norm: + x = BatchNormalization()(x) x_out = Add()([x_identity, x]) return x_out - diff --git a/sleap/nn/architectures/densenet.py b/sleap/nn/architectures/densenet.py index 826b27263..97d840ac0 100644 --- a/sleap/nn/architectures/densenet.py +++ b/sleap/nn/architectures/densenet.py @@ -14,6 +14,7 @@ from keras import backend, layers, models import keras.utils as keras_utils + def dense_block(x, blocks, name): """A dense block. # Arguments @@ -24,7 +25,7 @@ def dense_block(x, blocks, name): output tensor for the block. """ for i in range(blocks): - x = conv_block(x, 32, name=name + '_block' + str(i + 1)) + x = conv_block(x, 32, name=name + "_block" + str(i + 1)) return x @@ -37,14 +38,16 @@ def transition_block(x, reduction, name): # Returns output tensor for the block. """ - bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1 - x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5, - name=name + '_bn')(x) - x = layers.Activation('relu', name=name + '_relu')(x) - x = layers.Conv2D(int(backend.int_shape(x)[bn_axis] * reduction), 1, - use_bias=False, - name=name + '_conv')(x) - x = layers.AveragePooling2D(2, strides=2, name=name + '_pool')(x) + bn_axis = 3 if backend.image_data_format() == "channels_last" else 1 + x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name=name + "_bn")(x) + x = layers.Activation("relu", name=name + "_relu")(x) + x = layers.Conv2D( + int(backend.int_shape(x)[bn_axis] * reduction), + 1, + use_bias=False, + name=name + "_conv", + )(x) + x = layers.AveragePooling2D(2, strides=2, name=name + "_pool")(x) return x @@ -57,30 +60,24 @@ def conv_block(x, growth_rate, name): # Returns Output tensor for the block. """ - bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1 - x1 = layers.BatchNormalization(axis=bn_axis, - epsilon=1.001e-5, - name=name + '_0_bn')(x) - x1 = layers.Activation('relu', name=name + '_0_relu')(x1) - x1 = layers.Conv2D(4 * growth_rate, 1, - use_bias=False, - name=name + '_1_conv')(x1) - x1 = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5, - name=name + '_1_bn')(x1) - x1 = layers.Activation('relu', name=name + '_1_relu')(x1) - x1 = layers.Conv2D(growth_rate, 3, - padding='same', - use_bias=False, - name=name + '_2_conv')(x1) - x = layers.Concatenate(axis=bn_axis, name=name + '_concat')([x, x1]) + bn_axis = 3 if backend.image_data_format() == "channels_last" else 1 + x1 = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name=name + "_0_bn")( + x + ) + x1 = layers.Activation("relu", name=name + "_0_relu")(x1) + x1 = layers.Conv2D(4 * growth_rate, 1, use_bias=False, name=name + "_1_conv")(x1) + x1 = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name=name + "_1_bn")( + x1 + ) + x1 = layers.Activation("relu", name=name + "_1_relu")(x1) + x1 = layers.Conv2D( + growth_rate, 3, padding="same", use_bias=False, name=name + "_2_conv" + )(x1) + x = layers.Concatenate(axis=bn_axis, name=name + "_concat")([x, x1]) return x -def DenseNet(blocks, - output_channels, - input_tensor=None, - input_shape=None, - **kwargs): +def DenseNet(blocks, output_channels, input_tensor=None, input_shape=None, **kwargs): """Instantiates the DenseNet architecture. Optionally loads weights pre-trained on ImageNet. Note that the data format convention used by the model is @@ -123,7 +120,6 @@ def DenseNet(blocks, or invalid input shape. """ - if input_tensor is None: img_input = layers.Input(shape=input_shape) else: @@ -132,29 +128,29 @@ def DenseNet(blocks, else: img_input = input_tensor - bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1 + bn_axis = 3 if backend.image_data_format() == "channels_last" else 1 x = layers.ZeroPadding2D(padding=((3, 3), (3, 3)))(img_input) - x = layers.Conv2D(64, 7, strides=2, use_bias=False, name='conv1/conv')(x) - x = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, name='conv1/bn')(x) - x = layers.Activation('relu', name='conv1/relu')(x) + x = layers.Conv2D(64, 7, strides=2, use_bias=False, name="conv1/conv")(x) + x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name="conv1/bn")(x) + x = layers.Activation("relu", name="conv1/relu")(x) x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)))(x) - x = layers.MaxPooling2D(3, strides=2, name='pool1')(x) + x = layers.MaxPooling2D(3, strides=2, name="pool1")(x) - x = dense_block(x, blocks[0], name='conv2') - x = transition_block(x, 0.5, name='pool2') - x = dense_block(x, blocks[1], name='conv3') - x = transition_block(x, 0.5, name='pool3') - x = dense_block(x, blocks[2], name='conv4') - x = transition_block(x, 0.5, name='pool4') - x = dense_block(x, blocks[3], name='conv5') + x = dense_block(x, blocks[0], name="conv2") + x = transition_block(x, 0.5, name="pool2") + x = dense_block(x, blocks[1], name="conv3") + x = transition_block(x, 0.5, name="pool3") + x = dense_block(x, blocks[2], name="conv4") + x = transition_block(x, 0.5, name="pool4") + x = dense_block(x, blocks[3], name="conv5") - x = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, name='bn')(x) - x = layers.Activation('relu', name='relu')(x) + x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name="bn")(x) + x = layers.Activation("relu", name="relu")(x) - x = layers.Conv2D(filters=output_channels, kernel_size=(3, 3), padding="same", name="output")(x) + x = layers.Conv2D( + filters=output_channels, kernel_size=(3, 3), padding="same", name="output" + )(x) # Ensure that the model takes into account # any potential predecessors of `input_tensor`. @@ -165,13 +161,13 @@ def DenseNet(blocks, # Create model. if blocks == [6, 12, 24, 16]: - model = models.Model(inputs, x, name='densenet121') + model = models.Model(inputs, x, name="densenet121") elif blocks == [6, 12, 32, 32]: - model = models.Model(inputs, x, name='densenet169') + model = models.Model(inputs, x, name="densenet169") elif blocks == [6, 12, 48, 32]: - model = models.Model(inputs, x, name='densenet201') + model = models.Model(inputs, x, name="densenet201") else: - model = models.Model(inputs, x, name='densenet') + model = models.Model(inputs, x, name="densenet") return model @@ -215,4 +211,4 @@ def DenseNet(blocks, # include_top, weights, # input_tensor, input_shape, # pooling, classes, -# **kwargs) \ No newline at end of file +# **kwargs) diff --git a/sleap/nn/architectures/hourglass.py b/sleap/nn/architectures/hourglass.py index 2f33e343a..f520c8ec3 100644 --- a/sleap/nn/architectures/hourglass.py +++ b/sleap/nn/architectures/hourglass.py @@ -1,7 +1,21 @@ import attr -from sleap.nn.architectures.common import residual_block, expand_to_n, conv, conv1, conv3 -from keras.layers import Conv2D, BatchNormalization, Add, MaxPool2D, UpSampling2D, Concatenate, Conv2DTranspose +from sleap.nn.architectures.common import ( + residual_block, + expand_to_n, + conv, + conv1, + conv3, +) +from keras.layers import ( + Conv2D, + BatchNormalization, + Add, + MaxPool2D, + UpSampling2D, + Concatenate, + Conv2DTranspose, +) @attr.s(auto_attribs=True) @@ -33,16 +47,18 @@ class StackedHourglass: by concatenating with intermediate outputs upsampling_layers: Use upsampling instead of transposed convolutions. interp: Method to use for interpolation when upsampling smaller features. + initial_stride: Stride of first convolution to use for reducing input resolution. """ - num_hourglass_blocks: int = 3 + + num_stacks: int = 3 num_filters: int = 32 depth: int = 3 batch_norm: bool = True intermediate_inputs: bool = True upsampling_layers: bool = True interp: str = "bilinear" - + initial_stride: int = 1 def output(self, x_in, num_output_channels): """ @@ -60,7 +76,15 @@ def output(self, x_in, num_output_channels): return stacked_hourglass(x_in, num_output_channels, **attr.asdict(self)) -def hourglass_block(x_in, num_output_channels, num_filters, depth=3, batch_norm=True, upsampling_layers=True, interp="bilinear"): +def hourglass_block( + x_in, + num_output_channels, + num_filters, + depth=3, + batch_norm=True, + upsampling_layers=True, + interp="bilinear", +): """Creates a single hourglass block. This function builds an hourglass block from residual blocks and max pooling. @@ -94,15 +118,21 @@ def hourglass_block(x_in, num_output_channels, num_filters, depth=3, batch_norm= x_out: tf.Tensor of the output of the block of the same width and height as the input with `num_output_channels` channels. """ - + # Check if input tensor has the right number of channels if x_in.shape[-1] != num_filters: - raise ValueError("Input tensor must have the same number of channels as the intermediate output of the hourglass (%d)." % num_filters) - + raise ValueError( + "Input tensor must have the same number of channels as the intermediate output of the hourglass (%d)." + % num_filters + ) + # Check if input tensor has the right height/width for pooling given depth - if x_in.shape[-2] % (2**depth) != 0 or x_in.shape[-2] % (2**depth) != 0: - raise ValueError("Input tensor must have width and height dimensions divisible by %d." % (2**depth)) - + if x_in.shape[-2] % (2 ** depth) != 0 or x_in.shape[-2] % (2 ** depth) != 0: + raise ValueError( + "Input tensor must have width and height dimensions divisible by %d." + % (2 ** depth) + ) + # Down x = x_in blocks_down = [] @@ -110,41 +140,59 @@ def hourglass_block(x_in, num_output_channels, num_filters, depth=3, batch_norm= x = residual_block(x, num_filters, batch_norm) blocks_down.append(x) x = MaxPool2D(pool_size=(2, 2), strides=(2, 2))(x) - + x = residual_block(x, num_filters, batch_norm) - + # Middle x_identity = residual_block(x, num_filters, batch_norm) x = residual_block(x, num_filters, batch_norm) x = residual_block(x, num_filters, batch_norm) x = residual_block(x, num_filters, batch_norm) x = Add()([x_identity, x]) - + # Up for x_down in blocks_down[::-1]: x_down = residual_block(x_down, num_filters, batch_norm) if upsampling_layers: - x = UpSampling2D(size=(2,2), interpolation=interp)(x) + x = UpSampling2D(size=(2, 2), interpolation=interp)(x) else: - x = Conv2DTranspose(num_filters, kernel_size=3, strides=2, padding="same", activation="relu", kernel_initializer="glorot_normal")(x) + x = Conv2DTranspose( + num_filters, + kernel_size=3, + strides=2, + padding="same", + activation="relu", + kernel_initializer="glorot_normal", + )(x) x = Add()([x_down, x]) x = residual_block(x, num_filters, batch_norm) - + # Head x = conv1(num_filters)(x) - if batch_norm: x = BatchNormalization()(x) - + if batch_norm: + x = BatchNormalization()(x) + x_out = conv1(num_output_channels, activation="linear")(x) - + x = conv1(num_filters, activation="linear")(x) x_ = conv1(num_filters, activation="linear")(x_out) x = Add()([x_in, x, x_]) - + return x, x_out -def stacked_hourglass(x_in, num_output_channels, num_hourglass_blocks=3, num_filters=32, depth=3, batch_norm=True, - intermediate_inputs=True, upsampling_layers=True, interp="bilinear"): +def stacked_hourglass( + x_in, + num_output_channels, + num_stacks=3, + num_filters=32, + depth=3, + batch_norm=True, + intermediate_inputs=True, + upsampling_layers=True, + interp="bilinear", + initial_stride=1, +): """Stacked hourglass block. This function builds and connects multiple hourglass blocks. See `hourglass` for @@ -172,40 +220,49 @@ def stacked_hourglass(x_in, num_output_channels, num_hourglass_blocks=3, num_fil by concatenating with intermediate outputs upsampling_layers: Use upsampling instead of transposed convolutions. interp: Method to use for interpolation when upsampling smaller features. + initial_stride: Stride of first convolution to use for reducing input resolution. Returns: x_outs: List of tf.Tensors of the output of the block of the same width and height as the input with `num_output_channels` channels. """ + # Expand block-specific parameters if scalars provided + num_filters = expand_to_n(num_filters, num_stacks) + depth = expand_to_n(depth, num_stacks) + batch_norm = expand_to_n(batch_norm, num_stacks) + upsampling_layers = expand_to_n(upsampling_layers, num_stacks) + interp = expand_to_n(interp, num_stacks) + # Initial downsampling - x = conv(num_filters, kernel_size=(7, 7))(x_in) + x = conv(num_filters[0], kernel_size=(7, 7), strides=initial_stride)(x_in) # Batchnorm after the intial down sampling - if batch_norm: + if batch_norm[0]: x = BatchNormalization()(x) - # Expand block-specific parameters if scalars provided - num_filters = expand_to_n(num_filters, num_hourglass_blocks) - depth = expand_to_n(depth, num_hourglass_blocks) - batch_norm = expand_to_n(batch_norm, num_hourglass_blocks) - upsampling_layers = expand_to_n(upsampling_layers, num_hourglass_blocks) - interp = expand_to_n(interp, num_hourglass_blocks) - # Make sure first block gets the right number of channels - x = x_in + # x = x_in if x.shape[-1] != num_filters[0]: x = residual_block(x, num_filters[0], batch_norm[0]) - + # Create individual hourglasses and collect intermediate outputs + x_in = x x_outs = [] - for i in range(num_hourglass_blocks): + for i in range(num_stacks): if i > 0 and intermediate_inputs: x = Concatenate()([x, x_in]) x = residual_block(x, num_filters[i], batch_norm[i]) - x, x_out = hourglass_block(x, num_output_channels, num_filters[i], depth=depth[i], batch_norm=batch_norm[i], upsampling_layers=upsampling_layers[i], interp=interp[i]) + x, x_out = hourglass_block( + x, + num_output_channels, + num_filters[i], + depth=depth[i], + batch_norm=batch_norm[i], + upsampling_layers=upsampling_layers[i], + interp=interp[i], + ) x_outs.append(x_out) - - return x_outs + return x_outs diff --git a/sleap/nn/architectures/leap.py b/sleap/nn/architectures/leap.py index b943347af..51b418343 100644 --- a/sleap/nn/architectures/leap.py +++ b/sleap/nn/architectures/leap.py @@ -48,7 +48,15 @@ def output(self, x_in, num_output_channels): return leap_cnn(x_in, num_output_channels, **attr.asdict(self)) -def leap_cnn(x_in, num_output_channels, down_blocks=3, up_blocks=3, upsampling_layers=True, num_filters=64, interp="bilinear"): +def leap_cnn( + x_in, + num_output_channels, + down_blocks=3, + up_blocks=3, + upsampling_layers=True, + num_filters=64, + interp="bilinear", +): """LEAP CNN block. Implementation generalized from original paper (`Pereira et al., 2019 @@ -74,25 +82,50 @@ def leap_cnn(x_in, num_output_channels, down_blocks=3, up_blocks=3, upsampling_l """ # Check if input tensor has the right height/width for pooling given depth - if x_in.shape[-2] % (2**down_blocks) != 0 or x_in.shape[-2] % (2**down_blocks) != 0: - raise ValueError("Input tensor must have width and height dimensions divisible by %d." % (2**down_blocks)) + if ( + x_in.shape[-2] % (2 ** down_blocks) != 0 + or x_in.shape[-2] % (2 ** down_blocks) != 0 + ): + raise ValueError( + "Input tensor must have width and height dimensions divisible by %d." + % (2 ** down_blocks) + ) x = x_in for i in range(down_blocks): - x = Conv2D(num_filters * (2 ** i), kernel_size=3, padding="same", activation="relu")(x) - x = Conv2D(num_filters * (2 ** i), kernel_size=3, padding="same", activation="relu")(x) - x = Conv2D(num_filters * (2 ** i), kernel_size=3, padding="same", activation="relu")(x) + x = Conv2D( + num_filters * (2 ** i), kernel_size=3, padding="same", activation="relu" + )(x) + x = Conv2D( + num_filters * (2 ** i), kernel_size=3, padding="same", activation="relu" + )(x) + x = Conv2D( + num_filters * (2 ** i), kernel_size=3, padding="same", activation="relu" + )(x) x = MaxPool2D(pool_size=2, strides=2, padding="same")(x) for i in range(up_blocks, 0, -1): if upsampling_layers: x = UpSampling2D(interpolation=interp)(x) else: - x = Conv2DTranspose(num_filters * (2 ** i), kernel_size=3, strides=2, padding="same", activation="relu", kernel_initializer="glorot_normal")(x) - x = Conv2D(num_filters * (2 ** i), kernel_size=3, padding="same", activation="relu")(x) - x = Conv2D(num_filters * (2 ** i), kernel_size=3, padding="same", activation="relu")(x) - - x = Conv2D(num_output_channels, kernel_size=3, padding="same", activation="linear")(x) + x = Conv2DTranspose( + num_filters * (2 ** i), + kernel_size=3, + strides=2, + padding="same", + activation="relu", + kernel_initializer="glorot_normal", + )(x) + x = Conv2D( + num_filters * (2 ** i), kernel_size=3, padding="same", activation="relu" + )(x) + x = Conv2D( + num_filters * (2 ** i), kernel_size=3, padding="same", activation="relu" + )(x) + + x = Conv2D(num_output_channels, kernel_size=3, padding="same", activation="linear")( + x + ) return x diff --git a/sleap/nn/architectures/resnet.py b/sleap/nn/architectures/resnet.py new file mode 100644 index 000000000..4589fa1f3 --- /dev/null +++ b/sleap/nn/architectures/resnet.py @@ -0,0 +1,112 @@ +import tensorflow as tf +import keras +from keras import applications + +import attr + + +@attr.s(auto_attribs=True) +class ResNet50: + """ResNet50 pretrained backbone. + + Args: + x_in: Input 4-D tf.Tensor or instantiated layer. + num_output_channels: The number of output channels of the block. + upsampling_layers: Use upsampling instead of transposed convolutions. + interp: Method to use for interpolation when upsampling smaller features. + up_blocks: Number of upsampling steps to perform. The backbone reduces + the output scale by 1/32. If set to 5, outputs will be upsampled to the + input resolution. + refine_conv_up: If true, applies a 1x1 conv after each upsampling step. + pretrained: Load pretrained ImageNet weights for transfer learning. If + False, random weights are used for initialization. + """ + + upsampling_layers: bool = True + interp: str = "bilinear" + up_blocks: int = 5 + refine_conv_up: bool = False + pretrained: bool = True + + def output(self, x_in, num_output_channels): + """ + Generate a tensorflow graph for the backbone and return the output tensor. + + Args: + x_in: Input 4-D tf.Tensor or instantiated layer. Must have height and width + that are divisible by `2^down_blocks. + num_output_channels: The number of output channels of the block. These + are the final output tensors on which intermediate supervision may be + applied. + + Returns: + x_out: tf.Tensor of the output of the block of with `num_output_channels` channels. + """ + return resnet50(x_in, num_output_channels, **attr.asdict(self)) + + @property + def down_blocks(self): + """Returns the number of downsampling steps in the model.""" + + # This is a fixed constant for ResNet50. + return 5 + + @property + def output_scale(self): + """Returns relative scaling factor of this backbone.""" + + return 1 / (2 ** (self.down_blocks - self.up_blocks)) + + +def preprocess_input(X): + """Rescale input to [-1, 1] and tile if not RGB.""" + X = (X * 2) - 1 + + if tf.shape(X)[-1] != 3: + X = tf.tile(X, [1, 1, 1, 3]) + + return X + + +def resnet50( + x_in, + num_output_channels, + up_blocks=5, + upsampling_layers=True, + interp="bilinear", + refine_conv_up=False, + pretrained=True, +): + """Build ResNet50 backbone.""" + + # Input should be rescaled from [0, 1] to [-1, 1] and needs to be 3 channels (RGB) + x = keras.layers.Lambda(preprocess_input)(x_in) + + # Automatically downloads weights + resnet_model = applications.ResNet50( + include_top=False, + input_shape=(int(x_in.shape[-3]), int(x_in.shape[-2]), 3), + weights="imagenet" if pretrained else None, + ) + + # Output size is reduced by factor of 32 (2 ** 5) + x = resnet_model(x) + + for i in range(up_blocks): + if upsampling_layers: + x = keras.layers.UpSampling2D(size=(2, 2), interpolation=interp)(x) + else: + x = keras.layers.Conv2DTranspose( + 2 ** (8 - i), + kernel_size=3, + strides=2, + padding="same", + kernel_initializer="glorot_normal", + )(x) + + if refine_conv_up: + x = keras.layers.Conv2D(2 ** (8 - i), kernel_size=1, padding="same")(x) + + x = keras.layers.Conv2D(num_output_channels, (3, 3), padding="same")(x) + + return x diff --git a/sleap/nn/architectures/unet.py b/sleap/nn/architectures/unet.py index 6df5e02e2..75e37b4cd 100644 --- a/sleap/nn/architectures/unet.py +++ b/sleap/nn/architectures/unet.py @@ -29,6 +29,7 @@ class UNet: interp: Method to use for interpolation when upsampling smaller features. """ + down_blocks: int = 3 up_blocks: int = 3 convs_per_depth: int = 2 @@ -105,8 +106,17 @@ def output(self, x_in, num_output_channels): return stacked_unet(x_in, num_output_channels, **attr.asdict(self)) -def unet(x_in, num_output_channels, down_blocks=3, up_blocks=3, convs_per_depth=2, num_filters=16, - kernel_size=5, upsampling_layers=True, interp="bilinear"): +def unet( + x_in, + num_output_channels, + down_blocks=3, + up_blocks=3, + convs_per_depth=2, + num_filters=16, + kernel_size=5, + upsampling_layers=True, + interp="bilinear", +): """U-net block. Implementation based off of `CARE @@ -137,50 +147,73 @@ def unet(x_in, num_output_channels, down_blocks=3, up_blocks=3, convs_per_depth= """ # Check if input tensor has the right height/width for pooling given depth - if x_in.shape[-2] % (2**down_blocks) != 0 or x_in.shape[-2] % (2**down_blocks) != 0: - raise ValueError("Input tensor must have width and height dimensions divisible by %d." % (2**down_blocks)) + if ( + x_in.shape[-2] % (2 ** down_blocks) != 0 + or x_in.shape[-2] % (2 ** down_blocks) != 0 + ): + raise ValueError( + "Input tensor must have width and height dimensions divisible by %d." + % (2 ** down_blocks) + ) # Ensure we have a tuple in case scalar provided kernel_size = expand_to_n(kernel_size, 2) # Input tensor x = x_in - + # Downsampling skip_layers = [] for n in range(down_blocks): for i in range(convs_per_depth): x = conv(num_filters * 2 ** n, kernel_size=kernel_size)(x) skip_layers.append(x) - x = MaxPool2D(pool_size=(2,2))(x) + x = MaxPool2D(pool_size=(2, 2))(x) # Middle for i in range(convs_per_depth - 1): x = conv(num_filters * 2 ** down_blocks, kernel_size=kernel_size)(x) - x = conv(num_filters * 2 ** max(0, down_blocks-1), kernel_size=kernel_size)(x) + x = conv(num_filters * 2 ** max(0, down_blocks - 1), kernel_size=kernel_size)(x) # Upsampling (with skips) - for n in range(down_blocks-1, down_blocks-up_blocks-1, -1): + for n in range(down_blocks - 1, down_blocks - up_blocks - 1, -1): if upsampling_layers: - x = UpSampling2D(size=(2,2), interpolation=interp)(x) + x = UpSampling2D(size=(2, 2), interpolation=interp)(x) else: - x = Conv2DTranspose(num_filters * 2 ** n, kernel_size=kernel_size, strides=2, padding="same", activation="relu", kernel_initializer="glorot_normal")(x) + x = Conv2DTranspose( + num_filters * 2 ** n, + kernel_size=kernel_size, + strides=2, + padding="same", + activation="relu", + kernel_initializer="glorot_normal", + )(x) x = Concatenate(axis=-1)([x, skip_layers[n]]) - + for i in range(convs_per_depth - 1): x = conv(num_filters * 2 ** n, kernel_size=kernel_size)(x) - x = conv(num_filters * 2 ** max(0, n-1), kernel_size=kernel_size)(x) - + x = conv(num_filters * 2 ** max(0, n - 1), kernel_size=kernel_size)(x) + # Final layer x_out = conv(num_output_channels, activation="linear")(x) return x_out -def stacked_unet(x_in, num_output_channels, num_stacks=3, depth=3, convs_per_depth=2, num_filters=16, kernel_size=5, - upsampling_layers=True, intermediate_inputs=True, interp="bilinear"): +def stacked_unet( + x_in, + num_output_channels, + num_stacks=3, + depth=3, + convs_per_depth=2, + num_filters=16, + kernel_size=5, + upsampling_layers=True, + intermediate_inputs=True, + interp="bilinear", +): """Stacked U-net block. See `unet` for more specifics on the implementation. @@ -226,11 +259,17 @@ def stacked_unet(x_in, num_output_channels, num_stacks=3, depth=3, convs_per_dep if i > 0 and intermediate_inputs: x = Concatenate()([x, x_in]) - x_out = unet(x, num_output_channels, depth=depth[i], convs_per_depth=convs_per_depth[i], - num_filters=num_filters[i], kernel_size=kernel_size[i], - upsampling_layers=upsampling_layers[i], interp=interp[i]) + x_out = unet( + x, + num_output_channels, + depth=depth[i], + convs_per_depth=convs_per_depth[i], + num_filters=num_filters[i], + kernel_size=kernel_size[i], + upsampling_layers=upsampling_layers[i], + interp=interp[i], + ) x_outs.append(x_out) x = x_out - - return x_outs + return x_outs diff --git a/sleap/nn/augmentation.py b/sleap/nn/augmentation.py index 37fcd017e..dfe097a3d 100644 --- a/sleap/nn/augmentation.py +++ b/sleap/nn/augmentation.py @@ -57,7 +57,9 @@ def __attrs_post_init__(self): # Setup batching all_idx = np.arange(self.num_samples) - self.batches = np.array_split(all_idx, np.ceil(self.num_samples / self.batch_size)) + self.batches = np.array_split( + all_idx, np.ceil(self.num_samples / self.batch_size) + ) # Initial shuffling if self.shuffle_initially: @@ -67,10 +69,16 @@ def __attrs_post_init__(self): # TODO: translation? self.aug_stack = [] if self.rotation is not None: - self.rotation = self.rotation if isinstance(self.rotation, tuple) else (-self.rotation, self.rotation) + self.rotation = ( + self.rotation + if isinstance(self.rotation, tuple) + else (-self.rotation, self.rotation) + ) if self.scale is not None and self.scale[0] != self.scale[1]: self.scale = (min(self.scale), max(self.scale)) - self.aug_stack.append(imgaug.augmenters.Affine(rotate=self.rotation, scale=self.scale)) + self.aug_stack.append( + imgaug.augmenters.Affine(rotate=self.rotation, scale=self.scale) + ) else: self.aug_stack.append(imgaug.augmenters.Affine(rotate=self.rotation)) @@ -110,11 +118,13 @@ def shuffle(self, batches_only=False): # Re-batch after shuffling all_idx = np.arange(self.num_samples) np.random.shuffle(all_idx) - self.batches = np.array_split(all_idx, np.ceil(self.num_samples / self.batch_size)) - + self.batches = np.array_split( + all_idx, np.ceil(self.num_samples / self.batch_size) + ) + def __len__(self): return len(self.batches) - + def __getitem__(self, batch_idx): aug_det = self.aug.to_deterministic() idx = self.batches[batch_idx] @@ -137,18 +147,18 @@ def __getitem__(self, batch_idx): # Combine each list of point arrays (per frame) to single KeypointsOnImage # points: frames -> instances -> point_array frames_in_batch = [self.points[i] for i in idx] - points_per_instance_per_frame = [[pa.shape[0] for pa in frame] for frame in frames_in_batch] + points_per_instance_per_frame = [ + [pa.shape[0] for pa in frame] for frame in frames_in_batch + ] koi_in_frame = [] for i, frame in enumerate(frames_in_batch): if len(frame): koi = imgaug.augmentables.kps.KeypointsOnImage.from_xy_array( - np.concatenate(frame), - shape=X[i].shape) + np.concatenate(frame), shape=X[i].shape + ) else: - koi = imgaug.augmentables.kps.KeypointsOnImage( - [], - shape=X[i].shape) + koi = imgaug.augmentables.kps.KeypointsOnImage([], shape=X[i].shape) koi_in_frame.append(koi) # Augment KeypointsOnImage @@ -165,7 +175,7 @@ def __getitem__(self, batch_idx): frame_point_arrays = [] offset = 0 for point_count in points_per_instance_per_frame[i]: - inst_points = frame[offset:offset+point_count] + inst_points = frame[offset : offset + point_count] frame_point_arrays.append(inst_points) offset += point_count split_points.append(frame_point_arrays) @@ -192,18 +202,22 @@ def make_cattr(X=None, Y=None, Points=None): # parameters. aug_cattr.register_unstructure_hook( Augmenter, - lambda x: - attr.asdict(x, - filter=attr.filters.exclude( - attr.fields(Augmenter).X, - attr.fields(Augmenter).Y, - attr.fields(Augmenter).Points))) + lambda x: attr.asdict( + x, + filter=attr.filters.exclude( + attr.fields(Augmenter).X, + attr.fields(Augmenter).Y, + attr.fields(Augmenter).Points, + ), + ), + ) # We the user needs to unstructure, what images, outputs, and points should # they use. We didn't serialize these, just the parameters. if X is not None: - aug_cattr.register_structure_hook(Augmenter, - lambda x: Augmenter(X=X,Y=Y,Points=Points, **x)) + aug_cattr.register_structure_hook( + Augmenter, lambda x: Augmenter(X=X, Y=Y, Points=Points, **x) + ) return aug_cattr @@ -211,27 +225,24 @@ def make_cattr(X=None, Y=None, Points=None): def demo_augmentation(): from sleap.io.dataset import Labels from sleap.nn.datagen import generate_training_data - from sleap.nn.datagen import generate_confmaps_from_points, generate_pafs_from_points - - data_path = "tests/data/json_format_v2/centered_pair_predictions.json" - # data_path = "tests/data/json_format_v2/minimal_instance.json" -# data_path = "tests/data/json_format_v1/test.json" + from sleap.nn.datagen import ( + generate_confmaps_from_points, + generate_pafs_from_points, + ) + data_path = "tests/data/json_format_v1/centered_pair.json" labels = Labels.load_json(data_path) -# labels.labeled_frames = labels.labeled_frames[123:323:10] - # Generate raw training data skeleton = labels.skeletons[0] - imgs, points = generate_training_data(labels, params = dict( - scale = 1, - instance_crop = True, - min_crop_size = 0, - negative_samples = 0)) + imgs, points = generate_training_data( + labels, + params=dict(scale=1, instance_crop=True, min_crop_size=0, negative_samples=0), + ) shape = (imgs.shape[1], imgs.shape[2]) def datagen_from_points(points): -# return generate_pafs_from_points(points, skeleton, shape) + # return generate_pafs_from_points(points, skeleton, shape) return generate_confmaps_from_points(points, skeleton, shape) # Augment @@ -244,12 +255,13 @@ def datagen_from_points(points): from PySide2.QtWidgets import QApplication # Visualize augmented training data - vid = Video.from_numpy(imgs*255) + vid = Video.from_numpy(imgs * 255) app = QApplication([]) demo_confmaps(aug_out, vid) -# demo_pafs(aug_out, vid) + # demo_pafs(aug_out, vid) app.exec_() + def demo_bad_augmentation(): from sleap.io.dataset import Labels from sleap.nn.datagen import generate_images, generate_confidence_maps @@ -267,7 +279,7 @@ def demo_bad_augmentation(): confmaps = generate_confidence_maps(labels) # Augment - aug = Augmenter(X=imgs, Y=confmaps, scale=(.5, 2)) + aug = Augmenter(X=imgs, Y=confmaps, scale=(0.5, 2)) imgs, confmaps = aug[0] from sleap.io.video import Video @@ -276,11 +288,12 @@ def demo_bad_augmentation(): from PySide2.QtWidgets import QApplication # Visualize augmented training data - vid = Video.from_numpy(imgs*255) + vid = Video.from_numpy(imgs * 255) app = QApplication([]) demo_confmaps(confmaps, vid) app.exec_() + if __name__ == "__main__": demo_augmentation() diff --git a/sleap/nn/datagen.py b/sleap/nn/datagen.py index d3860b678..3507e30d1 100644 --- a/sleap/nn/datagen.py +++ b/sleap/nn/datagen.py @@ -10,6 +10,7 @@ from sleap.io.dataset import Labels + def generate_training_data(labels, params): """ Generate imgs (ndarray) and points (list) to use for training. @@ -33,53 +34,72 @@ def generate_training_data(labels, params): resize_hack = not params["instance_crop"] - imgs = generate_images(labels, params["scale"], - frame_limit=params.get("frame_limit", None), - resize_hack=resize_hack) + imgs = generate_images( + labels, + params["scale"], + frame_limit=params.get("frame_limit", None), + resize_hack=resize_hack, + ) - points = generate_points(labels, params["scale"], - frame_limit=params.get("frame_limit", None)) + points = generate_points( + labels, params["scale"], frame_limit=params.get("frame_limit", None) + ) if params["instance_crop"]: # Crop and include any *random* negative samples imgs, points = instance_crops( - imgs, points, - min_crop_size = params["min_crop_size"], - negative_samples = params["negative_samples"]) + imgs, + points, + min_crop_size=params["min_crop_size"], + negative_samples=params["negative_samples"], + ) # Include any *specific* negative samples imgs, points = add_negative_anchor_crops( - labels, - imgs, points, - scale=params["scale"]) + labels, imgs, points, scale=params["scale"] + ) return imgs, points -def generate_images(labels:Labels, scale: float=1.0, - resize_hack: bool=True, frame_limit: int=None) -> np.ndarray: + +def generate_images( + labels: Labels, + scale: float = 1.0, + resize_hack: bool = True, + frame_limit: int = None, +) -> np.ndarray: """ Generate a ndarray of the image data for any user labeled frames. Wrapper that calls generate_images_from_list() with list of all frames that were labeled by user. """ - frame_list = [(lf.video, lf.frame_idx) - for lf in labels.user_labeled_frames[:frame_limit]] + frame_list = [ + (lf.video, lf.frame_idx) for lf in labels.user_labeled_frames[:frame_limit] + ] return generate_images_from_list(labels, frame_list, scale, resize_hack) -def generate_points(labels:Labels, scale: float=1.0, frame_limit: int=None) -> list: + +def generate_points( + labels: Labels, scale: float = 1.0, frame_limit: int = None +) -> list: """Generates point data for instances for any user labeled frames. Wrapper that calls generate_points_from_list() with list of all frames that were labeled by user. """ - frame_list = [(lf.video, lf.frame_idx) - for lf in labels.user_labeled_frames[:frame_limit]] + frame_list = [ + (lf.video, lf.frame_idx) for lf in labels.user_labeled_frames[:frame_limit] + ] return generate_points_from_list(labels, frame_list, scale) + def generate_images_from_list( - labels:Labels, frame_list: List[Tuple], - scale: float=1.0, resize_hack: bool=True) -> np.ndarray: + labels: Labels, + frame_list: List[Tuple], + scale: float = 1.0, + resize_hack: bool = True, +) -> np.ndarray: """ Generate a ndarray of the image data for given list of frames @@ -98,11 +118,11 @@ def generate_images_from_list( # rescale by factor y, x, c = img.shape if scale != 1.0 or resize_hack: - y_scaled, x_scaled = int(y//(1/scale)), int(x//(1/scale)) + y_scaled, x_scaled = int(y // (1 / scale)), int(x // (1 / scale)) # FIXME: hack to resize image so dimensions are divisible by 8 if resize_hack: - y_scaled, x_scaled = y_scaled//8*8, x_scaled//8*8 + y_scaled, x_scaled = y_scaled // 8 * 8, x_scaled // 8 * 8 if (x, y) != (x_scaled, y_scaled): # resize image @@ -122,7 +142,10 @@ def generate_images_from_list( return imgs -def generate_points_from_list(labels:Labels, frame_list: List[Tuple], scale: float=1.0) -> list: + +def generate_points_from_list( + labels: Labels, frame_list: List[Tuple], scale: float = 1.0 +) -> list: """Generates point data for instances in specified frames. Output is in the format expected by @@ -139,22 +162,28 @@ def generate_points_from_list(labels:Labels, frame_list: List[Tuple], scale: flo a list (each frame) of lists (each instance) of ndarrays (of points) i.e., frames -> instances -> point_array """ + def lf_points_from_singleton(lf_singleton): - if len(lf_singleton) == 0: return [] + if len(lf_singleton) == 0: + return [] lf = lf_singleton[0] - points = [inst.points_array(invisible_as_nan=True)*scale - for inst in lf.user_instances] + points = [inst.points_array * scale for inst in lf.user_instances] return points lfs = [labels.find(video, frame_idx) for (video, frame_idx) in frame_list] return list(map(lf_points_from_singleton, lfs)) -def generate_confmaps_from_points(frames_inst_points, - skeleton: Optional['Skeleton'], - shape, - node_count: Optional[int] = None, - sigma:float=5.0, scale:float=1.0, output_size=None) -> np.ndarray: + +def generate_confmaps_from_points( + frames_inst_points, + skeleton: Optional["Skeleton"], + shape, + node_count: Optional[int] = None, + sigma: float = 5.0, + scale: float = 1.0, + output_size=None, +) -> np.ndarray: """ Generates confmaps for set of frames. This is used to generate confmaps on the fly during training, @@ -175,15 +204,16 @@ def generate_confmaps_from_points(frames_inst_points, full_size = shape if output_size is None: - output_size = (shape[0] // (1/scale), shape[1] // (1/scale)) + output_size = (shape[0] // (1 / scale), shape[1] // (1 / scale)) output_size = tuple(map(int, output_size)) - ball = _get_conf_ball(output_size, sigma*scale) + ball = _get_conf_ball(output_size, sigma * scale) num_frames = len(frames_inst_points) - confmaps = np.zeros((num_frames, output_size[0], output_size[1], node_count), - dtype="float32") + confmaps = np.zeros( + (num_frames, output_size[0], output_size[1], node_count), dtype="float32" + ) for frame_idx, points_arrays in enumerate(frames_inst_points): for inst_points in points_arrays: @@ -191,12 +221,21 @@ def generate_confmaps_from_points(frames_inst_points, if not np.isnan(np.sum(inst_points[node_idx])): x = inst_points[node_idx][0] * scale y = inst_points[node_idx][1] * scale - _raster_ball(arr=confmaps[frame_idx], ball=ball, c=node_idx, x=x, y=y) + _raster_ball( + arr=confmaps[frame_idx], ball=ball, c=node_idx, x=x, y=y + ) return confmaps -def generate_pafs_from_points(frames_inst_points, skeleton, shape, - sigma:float=5.0, scale:float=1.0, output_size=None) -> np.ndarray: + +def generate_pafs_from_points( + frames_inst_points, + skeleton, + shape, + sigma: float = 5.0, + scale: float = 1.0, + output_size=None, +) -> np.ndarray: """ Generates pafs for set of frames. This is used to generate pafs on the fly during training, @@ -212,7 +251,7 @@ def generate_pafs_from_points(frames_inst_points, skeleton, shape, """ full_size = shape if output_size is None: - output_size = (shape[0] // (1/scale), shape[1] // (1/scale)) + output_size = (shape[0] // (1 / scale), shape[1] // (1 / scale)) # TODO: throw warning for truncation errors full_size = tuple(map(int, full_size)) @@ -221,8 +260,9 @@ def generate_pafs_from_points(frames_inst_points, skeleton, shape, num_frames = len(frames_inst_points) num_channels = len(skeleton.edges) * 2 - pafs = np.zeros((num_frames, output_size[0], output_size[1], num_channels), - dtype="float32") + pafs = np.zeros( + (num_frames, output_size[0], output_size[1], num_channels), dtype="float32" + ) for frame_idx, points_arrays in enumerate(frames_inst_points): for inst_points in points_arrays: for c, (src_node, dst_node) in enumerate(skeleton.edges): @@ -236,19 +276,23 @@ def generate_pafs_from_points(frames_inst_points, skeleton, shape, return pafs + def _get_conf_ball(output_size, sigma): # Pre-allocate coordinate grid xv = np.linspace(0, output_size[1] - 1, output_size[1], dtype="float32") yv = np.linspace(0, output_size[0] - 1, output_size[0], dtype="float32") XX, YY = np.meshgrid(xv, yv) - x, y = output_size[1]//2, output_size[0]//2 + x, y = output_size[1] // 2, output_size[0] // 2 ball_full = np.exp(-((YY - y) ** 2 + (XX - x) ** 2) / (2 * sigma ** 2)) - window_size = int(sigma*4) - ball_window = ball_full[y-window_size:y+window_size, x-window_size:x+window_size] + window_size = int(sigma * 4) + ball_window = ball_full[ + y - window_size : y + window_size, x - window_size : x + window_size + ] return ball_window + def _raster_ball(arr, ball, c, x, y): x, y = int(x), int(y) ball_h, ball_w = ball.shape @@ -257,8 +301,8 @@ def _raster_ball(arr, ball, c, x, y): ball_slice_y = slice(0, ball_h) ball_slice_x = slice(0, ball_w) - arr_slice_y = slice(y-ball_h//2, y+ball_h//2) - arr_slice_x = slice(x-ball_w//2, x+ball_w//2) + arr_slice_y = slice(y - ball_h // 2, y + ball_h // 2) + arr_slice_x = slice(x - ball_w // 2, x + ball_w // 2) # crop ball if it would be out of array bounds # i.e., it's close to edge @@ -275,24 +319,26 @@ def _raster_ball(arr, ball, c, x, y): if arr_slice_y.stop > out_h: cut = arr_slice_y.stop - out_h arr_slice_y = slice(arr_slice_y.start, out_h) - ball_slice_y = slice(0, ball_h-cut) + ball_slice_y = slice(0, ball_h - cut) if arr_slice_x.stop > out_w: cut = arr_slice_x.stop - out_w arr_slice_x = slice(arr_slice_x.start, out_w) - ball_slice_x = slice(0, ball_w-cut) + ball_slice_x = slice(0, ball_w - cut) - if ball_slice_x.stop <= ball_slice_x.start \ - or ball_slice_y.stop <= ball_slice_y.start: + if ( + ball_slice_x.stop <= ball_slice_x.start + or ball_slice_y.stop <= ball_slice_y.start + ): return # impose ball on array arr[arr_slice_y, arr_slice_x, c] = np.maximum( - arr[arr_slice_y, arr_slice_x, c], - ball[ball_slice_y, ball_slice_x] - ) + arr[arr_slice_y, arr_slice_x, c], ball[ball_slice_y, ball_slice_x] + ) + -def generate_confidence_maps(labels:Labels, sigma=5.0, scale=1): +def generate_confidence_maps(labels: Labels, sigma=5.0, scale=1): """Wrapper for generate_confmaps_from_points which takes labels instead of points.""" # TODO: multi-skeleton support @@ -306,16 +352,19 @@ def generate_confidence_maps(labels:Labels, sigma=5.0, scale=1): return confmaps + def _raster_pafs(arr, c, x0, y0, x1, y1, sigma): # skip if any nan - if np.isnan(np.sum((x0, y0, x1, y1))): return + if np.isnan(np.sum((x0, y0, x1, y1))): + return delta_x, delta_y = x1 - x0, y1 - y0 - edge_len = (delta_x ** 2 + delta_y ** 2) ** .5 + edge_len = (delta_x ** 2 + delta_y ** 2) ** 0.5 # skip if no distance between nodes - if edge_len == 0.0: return + if edge_len == 0.0: + return edge_x = delta_x / edge_len edge_y = delta_y / edge_len @@ -330,6 +379,7 @@ def _raster_pafs(arr, c, x0, y0, x1, y1, sigma): yy = perp_y0, perp_y0 + delta_y, perp_y1 + delta_y, perp_y1 from skimage.draw import polygon, polygon_perimeter + points_y, points_x = polygon(yy, xx, (arr.shape[0], arr.shape[1])) perim_y, perim_x = polygon_perimeter(yy, xx, shape=(arr.shape[0], arr.shape[1])) @@ -341,7 +391,8 @@ def _raster_pafs(arr, c, x0, y0, x1, y1, sigma): arr[y, x, c] = edge_x arr[y, x, c + 1] = edge_y -def generate_pafs(labels: Labels, sigma:float=5.0, scale:float=1.0) -> np.ndarray: + +def generate_pafs(labels: Labels, sigma: float = 5.0, scale: float = 1.0) -> np.ndarray: """Wrapper for generate_pafs_from_points which takes labels instead of points.""" # TODO: multi-skeleton support @@ -355,6 +406,7 @@ def generate_pafs(labels: Labels, sigma:float=5.0, scale:float=1.0) -> np.ndarra return pafs + def point_array_bounding_box(point_array: np.ndarray) -> tuple: """Returns (x0, y0, x1, y1) for box that bounds point_array.""" x0 = np.nanmin(point_array[:, 0]) @@ -363,6 +415,7 @@ def point_array_bounding_box(point_array: np.ndarray) -> tuple: y1 = np.nanmax(point_array[:, 1]) return x0, y0, x1, y1 + def pad_rect_to(x0: int, y0: int, x1: int, y1: int, pad_to: tuple, within: tuple): """Grow (x0, y0, x1, y1) so it's as large as pad_to but stays inside within. @@ -381,14 +434,14 @@ def pad_rect_to(x0: int, y0: int, x1: int, y1: int, pad_to: tuple, within: tuple * 0 <= (x1-x0) <= within w """ pad_to_y, pad_to_x = pad_to - x_margin = pad_to_x - (x1-x0) - y_margin = pad_to_y - (y1-y0) + x_margin = pad_to_x - (x1 - x0) + y_margin = pad_to_y - (y1 - y0) # initial values - x0 -= x_margin//2 - x1 += x_margin-x_margin//2 - y0 -= y_margin//2 - y1 += y_margin-y_margin//2 + x0 -= x_margin // 2 + x1 += x_margin - x_margin // 2 + y0 -= y_margin // 2 + y1 += y_margin - y_margin // 2 # adjust to stay inside within within_y, within_x = within @@ -397,36 +450,40 @@ def pad_rect_to(x0: int, y0: int, x1: int, y1: int, pad_to: tuple, within: tuple x1 = min(within_x, pad_to_x) if x1 > within_x: x1 = within_x - x0 = max(0, within_x-pad_to_x) + x0 = max(0, within_x - pad_to_x) if y0 < 0: y0 = 0 y1 = min(within_y, pad_to_y) if y1 > within_y: y1 = within_y - y0 = max(0, within_y-pad_to_y) + y0 = max(0, within_y - pad_to_y) return x0, y0, x1, y1 + def generate_centroid_points(points: list) -> list: """Takes the points for each instance and replaces it with a single centroid point.""" - centroids = [[_centroid(*point_array_bounding_box(point_array)) - for point_array in frame] for frame in points] + centroids = [ + [_centroid(*point_array_bounding_box(point_array)) for point_array in frame] + for frame in points + ] return centroids + def _to_np_point(x, y) -> np.ndarray: a = np.array((x, y)) return np.expand_dims(a, axis=0) + def _centroid(x0, y0, x1, y1) -> np.ndarray: - return _to_np_point(x = x0+(x1-x0)/2, y = y0+(y1-y0)/2) + return _to_np_point(x=x0 + (x1 - x0) / 2, y=y0 + (y1 - y0) / 2) + def instance_crops( - imgs: np.ndarray, - points: list, - min_crop_size: int=0, - negative_samples: int=0) -> Tuple[np.ndarray, List]: + imgs: np.ndarray, points: list, min_crop_size: int = 0, negative_samples: int = 0 +) -> Tuple[np.ndarray, List]: """ Take imgs, points and return imgs, points cropped around instances. @@ -455,29 +512,38 @@ def instance_crops( # Add bounding boxes for *random* negative samples if negative_samples > 0: - neg_img_idxs, neg_bbs = get_random_negative_samples(img_idxs, bbs, img_shape, negative_samples) + neg_img_idxs, neg_bbs = get_random_negative_samples( + img_idxs, bbs, img_shape, negative_samples + ) neg_imgs, neg_points = _crop_and_transform(imgs, points, neg_img_idxs, neg_bbs) - crop_imgs, crop_points = _extend_imgs_points(crop_imgs, crop_points, neg_imgs, neg_points) + crop_imgs, crop_points = _extend_imgs_points( + crop_imgs, crop_points, neg_imgs, neg_points + ) return crop_imgs, crop_points + def _crop_and_transform(imgs, points, img_idxs, bbs): crop_imgs = _crop(imgs, img_idxs, bbs) crop_points = _transform_crop_points(points, img_idxs, bbs) return crop_imgs, crop_points + def _extend_imgs_points(imgs_a, points_a, imgs_b, points_b): imgs = np.concatenate((imgs_a, imgs_b)) points = points_a + points_b return imgs, points + def _pad_bbs_to_min(bbs, min_crop_size, img_shape): padded_bbs = _pad_bbs( - bbs = bbs, - box_shape = _bb_pad_shape(bbs, min_crop_size, img_shape), - img_shape = img_shape) + bbs=bbs, + box_shape=_bb_pad_shape(bbs, min_crop_size, img_shape), + img_shape=img_shape, + ) return padded_bbs + def _bb_pad_shape(bbs, min_crop_size, img_shape): """ Given a list of bounding boxes, finds the square size which will be: @@ -491,13 +557,16 @@ def _bb_pad_shape(bbs, min_crop_size, img_shape): Returns: (size, size) tuple """ + + # TODO: Holy hardcoded fuck Batman! This really needs to get cleaned up + # Find a nicely sized box that's large enough to bound all instances max_height = max((y1 - y0 for (x0, y0, x1, y1) in bbs)) max_width = max((x1 - x0 for (x0, y0, x1, y1) in bbs)) max_dim = max(max_height, max_width) max_dim = max(max_dim, min_crop_size) - max_dim += 20 # pad - box_side = ceil(max_dim/64)*64 # round up to nearest multiple of 64 + max_dim += 20 # pad + box_side = ceil(max_dim / 64) * 64 # round up to nearest multiple of 64 # TODO: make sure we have valid box_size @@ -506,6 +575,7 @@ def _bb_pad_shape(bbs, min_crop_size, img_shape): return box_shape + def _transform_crop_points(points, img_idxs, bbs): """Takes points on the original images and returns points in bounding boxes. @@ -530,14 +600,19 @@ def _transform_crop_points(points, img_idxs, bbs): crop_points = list(map(lambda i: points[i], img_idxs)) # translate points to location w/in cropped image - crop_points = [_translate_points_array(points_array, bbs[i][0], bbs[i][1]) - for i, points_array in enumerate(crop_points)] + crop_points = [ + _translate_points_array(points_array, bbs[i][0], bbs[i][1]) + for i, points_array in enumerate(crop_points) + ] return crop_points + def _translate_points_array(points_array, x, y): - if len(points_array) == 0: return points_array - return points_array - np.asarray([x,y]) + if len(points_array) == 0: + return points_array + return points_array - np.asarray([x, y]) + def merge_boxes(box_a, box_b): """Return a box that contains both boxes.""" @@ -550,10 +625,12 @@ def merge_boxes(box_a, box_b): return (c_x1, c_y1, c_x2, c_y2) + def merge_boxes_with_overlap(boxes): """Return a list of boxes after merging any overlapping boxes.""" - if len(boxes) < 2: return boxes + if len(boxes) < 2: + return boxes first_box = boxes[0] other_boxes = boxes[1:] @@ -572,6 +649,7 @@ def merge_boxes_with_overlap(boxes): return [first_box] + other_boxes + def merge_boxes_with_overlap_and_padding(boxes, pad_factor_box, within): """ Returns a list of boxes after merging any overlapping boxes @@ -586,23 +664,28 @@ def merge_boxes_with_overlap_and_padding(boxes, pad_factor_box, within): if len(merged_boxes) == len(boxes): return merged_boxes else: - return merge_boxes_with_overlap_and_padding(merged_boxes, pad_factor_box, within) + return merge_boxes_with_overlap_and_padding( + merged_boxes, pad_factor_box, within + ) + def pad_box_to_multiple(box, pad_factor_box, within): - box_h = box[3] - box[1] # difference in y - box_w = box[2] - box[0] # difference in x + box_h = box[3] - box[1] # difference in y + box_w = box[2] - box[0] # difference in x pad_h, pad_w = pad_factor_box # Find multiple of pad_factor_box that's large enough to hold box - multiple_h, multiple_w = ceil(box_h / pad_h), ceil(box_w / pad_w) + multiple_h = ceil(box_h / pad_h) + multiple_w = ceil(box_w / pad_w) # Maintain aspect ratio multiple = max(multiple_h, multiple_w) # Return padded box - return pad_rect_to(*box, (pad_h*multiple, pad_w*multiple), within) + return pad_rect_to(*box, (pad_h * multiple, pad_w * multiple), within) + def bounding_box_nms(boxes, scores, iou_threshold): """ @@ -635,10 +718,10 @@ def bounding_box_nms(boxes, scores, iou_threshold): pick = [] # grab the coordinates of the bounding boxes - x1 = boxes[:,0] - y1 = boxes[:,1] - x2 = boxes[:,2] - y2 = boxes[:,3] + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] # compute the area of the bounding boxes, sort by scores area = (x2 - x1 + 1) * (y2 - y1 + 1) @@ -669,21 +752,23 @@ def bounding_box_nms(boxes, scores, iou_threshold): overlap = (w * h) / area[idxs[:last]] # delete all indexes from the index list that have - idxs = np.delete(idxs, np.concatenate(([last], - np.where(overlap > iou_threshold)[0]))) + idxs = np.delete( + idxs, np.concatenate(([last], np.where(overlap > iou_threshold)[0])) + ) # return the list of picked boxes return pick + def negative_anchor_crops( - labels: Labels, - negative_anchors: Dict['Video', Dict[int, Tuple]], - scale, crop_size) -> Tuple[np.ndarray, List]: + labels: Labels, negative_anchors: Dict["Video", Dict[int, Tuple]], scale, crop_size +) -> Tuple[np.ndarray, List]: """ Returns crops around *specific* negative samples from Labels object. Args: labels: the `Labels` object + negative_anchors: The anchors for negative training samples. scale: scale, should match scale given to generate_images() crop_size: the size of the crops returned by instance_crops() Returns: @@ -694,16 +779,21 @@ def negative_anchor_crops( # negative_anchors[video]: (frame_idx, x, y) for center of crop - neg_anchor_tuples = [(video, frame_idx, x, y) - for video in negative_anchors - for (frame_idx, x, y) in negative_anchors[video]] + # Filter negative anchors so we only include frames with labeled data + training_frames = [(lf.video, lf.frame_idx) for lf in labels.user_labeled_frames] + + neg_anchor_tuples = [ + (video, frame_idx, x, y) + for video in negative_anchors + for (frame_idx, x, y) in negative_anchors[video] + if (video, frame_idx) in training_frames + ] - if len(neg_anchor_tuples) == 0: return None, None + if len(neg_anchor_tuples) == 0: + return None, None - frame_list = [(video, frame_idx) - for (video, frame_idx, x, y) in neg_anchor_tuples] - anchors = [[_to_np_point(x,y)] - for (video, frame_idx, x, y) in neg_anchor_tuples] + frame_list = [(video, frame_idx) for (video, frame_idx, x, y) in neg_anchor_tuples] + anchors = [[_to_np_point(x, y)] for (video, frame_idx, x, y) in neg_anchor_tuples] imgs = generate_images_from_list(labels, frame_list, scale) points = generate_points_from_list(labels, frame_list, scale) @@ -719,39 +809,49 @@ def negative_anchor_crops( return crop_imgs, crop_points -def add_negative_anchor_crops(labels: Labels, imgs: np.ndarray, points: list, scale: float) -> Tuple[np.ndarray, List]: + +def add_negative_anchor_crops( + labels: Labels, imgs: np.ndarray, points: list, scale: float +) -> Tuple[np.ndarray, List]: """Wrapper to build and append negative anchor crops.""" # Include any *specific* negative samples neg_imgs, neg_points = negative_anchor_crops( - labels, - labels.negative_anchors, - scale=scale, - crop_size=imgs.shape[1]) + labels, labels.negative_anchors, scale=scale, crop_size=imgs.shape[1] + ) if neg_imgs is not None: imgs, points = _extend_imgs_points(imgs, points, neg_imgs, neg_points) return imgs, points + def get_random_negative_samples(img_idxs, bbs, img_shape, negative_samples): - if len(bbs) == 0: return + if len(bbs) == 0: + return frame_count = len({frame for frame in img_idxs}) - box_side = bbs[0][2] - bbs[0][0] # x1 - x0 for the first bb + box_side = bbs[0][2] - bbs[0][0] # x1 - x0 for the first bb neg_sample_list = [] # Collect negative samples (and some extras) - for _ in range(max(int(negative_samples*1.5), negative_samples+10)): + for _ in range(max(int(negative_samples * 1.5), negative_samples + 10)): # find negative sample # pick a random image sample_img_idx = random.randrange(frame_count) # pick a random box within image - x, y = random.randrange(img_shape[1] - box_side), random.randrange(img_shape[0] - box_side) - sample_bb = (x, y, x+box_side, y+box_side) + x, y = ( + random.randrange(img_shape[1] - box_side), + random.randrange(img_shape[0] - box_side), + ) + sample_bb = (x, y, x + box_side, y + box_side) - frame_bbs = [bbs[i] for i, frame in enumerate(img_idxs) if frame == sample_img_idx] - area_covered = sum(map(lambda bb: box_overlap_area(sample_bb, bb), frame_bbs))/(box_side**2) + frame_bbs = [ + bbs[i] for i, frame in enumerate(img_idxs) if frame == sample_img_idx + ] + area_covered = sum( + map(lambda bb: box_overlap_area(sample_bb, bb), frame_bbs) + ) / (box_side ** 2) # append negative sample to lists neg_sample_list.append((area_covered, sample_img_idx, sample_bb)) @@ -762,9 +862,14 @@ def get_random_negative_samples(img_idxs, bbs, img_shape, negative_samples): return neg_img_idxs[:negative_samples], neg_bbs[:negative_samples] + def _bbs_from_points(points): # List of bounding box for every instance - bbs = [point_array_bounding_box(point_array) for frame in points for point_array in frame] + bbs = [ + point_array_bounding_box(point_array) + for frame in points + for point_array in frame + ] bbs = [(int(x0), int(y0), int(x1), int(y1)) for (x0, y0, x1, y1) in bbs] # List to map bb to its img frame idx @@ -772,6 +877,7 @@ def _bbs_from_points(points): return bbs, img_idxs + def box_overlap_area(box_a, box_b): # determine the (x, y)-coordinates of the intersection rectangle xA = max(box_a[0], box_b[0]) @@ -784,16 +890,22 @@ def box_overlap_area(box_a, box_b): return inter_area + def _pad_bbs(bbs, box_shape, img_shape): return list(map(lambda bb: pad_rect_to(*bb, box_shape, img_shape), bbs)) + def _crop(imgs, img_idxs, bbs) -> np.ndarray: - imgs = [imgs[img_idxs[i], bb[1]:bb[3], bb[0]:bb[2]] for i, bb in enumerate(bbs)] # imgs[frame_idx, y0:y1, x0:x1] + imgs = [ + imgs[img_idxs[i], bb[1] : bb[3], bb[0] : bb[2]] for i, bb in enumerate(bbs) + ] # imgs[frame_idx, y0:y1, x0:x1] imgs = np.stack(imgs, axis=0) return imgs -def fullsize_points_from_crop(idx: int, point_array: np.ndarray, - bbs: list, img_idxs: list): + +def fullsize_points_from_crop( + idx: int, point_array: np.ndarray, bbs: list, img_idxs: list +): """ Map point within crop back to original image frames. @@ -807,13 +919,14 @@ def fullsize_points_from_crop(idx: int, point_array: np.ndarray, """ bb = bbs[idx] - top_left_point = ((bb[0], bb[1]),) # for (x, y) column vector + top_left_point = ((bb[0], bb[1]),) # for (x, y) column vector point_array += np.array(top_left_point) frame_idx = img_idxs[idx] return frame_idx, point_array + def demo_datagen_time(): data_path = "tests/data/json_format_v2/centered_pair_predictions.json" @@ -824,7 +937,10 @@ def demo_datagen_time(): timing_reps = 1 import timeit - t = timeit.timeit("generate_confidence_maps(labels)", number=timing_reps, globals=globals()) + + t = timeit.timeit( + "generate_confidence_maps(labels)", number=timing_reps, globals=globals() + ) t /= timing_reps print(f"confmaps time: {t} = {t/count} s/frame for {count} frames") @@ -832,28 +948,22 @@ def demo_datagen_time(): t /= timing_reps print(f"pafs time: {t} = {t/count} s/frame for {count} frames") + def demo_datagen(): - import os - data_path = "C:/Users/tdp/OneDrive/code/sandbox/leap_wt_gold_pilot/centered_pair.json" - if not os.path.exists(data_path): - data_path = "tests/data/json_format_v1/centered_pair.json" - # data_path = "tests/data/json_format_v2/minimal_instance.json" + data_path = "tests/data/json_format_v1/centered_pair.json" + data_path = "/Users/tabris/Desktop/macpaths.json.h5" - labels = Labels.load_json(data_path) - # testing - labels.negative_anchors = {labels.videos[0]: [(0, 125, 125), (0, 150, 150)]} - # labels.labeled_frames = labels.labeled_frames[123:423:10] + labels = Labels.load_file(data_path) scale = 1 imgs, points = generate_training_data( - labels = labels, - params = dict( - scale = scale, - instance_crop = True, - min_crop_size = 0, - negative_samples = 0)) + labels=labels, + params=dict( + scale=scale, instance_crop=True, min_crop_size=0, negative_samples=0 + ), + ) print("--imgs--") print(imgs.shape) @@ -886,7 +996,9 @@ def demo_datagen(): skeleton = labels.skeletons[0] img_shape = (imgs.shape[1], imgs.shape[2]) - confmaps = generate_confmaps_from_points(points, skeleton, img_shape, scale=.5, sigma=5.0*scale) + confmaps = generate_confmaps_from_points( + points, skeleton, img_shape, scale=0.5, sigma=5.0 * scale + ) print("--confmaps--") print(confmaps.shape) print(confmaps.dtype) @@ -894,7 +1006,9 @@ def demo_datagen(): demo_confmaps(confmaps, vid) - pafs = generate_pafs_from_points(points, skeleton, img_shape, scale=.5, sigma=5.0*scale) + pafs = generate_pafs_from_points( + points, skeleton, img_shape, scale=0.5, sigma=5.0 * scale + ) print("--pafs--") print(pafs.shape) print(pafs.dtype) @@ -904,5 +1018,6 @@ def demo_datagen(): app.exec_() + if __name__ == "__main__": - demo_datagen() \ No newline at end of file + demo_datagen() diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index ee72d6dcb..37aaa54bb 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -4,6 +4,7 @@ import os import json import logging + logger = logging.getLogger(__name__) import numpy as np @@ -18,9 +19,7 @@ from multiprocessing.pool import AsyncResult, ThreadPool from time import time, clock -from typing import Dict, List, Union, Optional, Tuple - -from keras.utils import multi_gpu_model +from typing import Any, Dict, List, Union, Optional, Text, Tuple from sleap.instance import LabeledFrame from sleap.io.dataset import Labels @@ -34,13 +33,235 @@ from sleap.nn.transform import DataTransform from sleap.nn.datagen import merge_boxes_with_overlap_and_padding -from sleap.nn.loadmodel import load_model, get_model_data, get_model_skeleton from sleap.nn.peakfinding import find_all_peaks, find_all_single_peaks from sleap.nn.peakfinding_tf import peak_tf_inference -from sleap.nn.peakmatching import match_single_peaks_all, match_peaks_paf, match_peaks_paf_par, instances_nms +from sleap.nn.peakmatching import ( + match_single_peaks_all, + match_peaks_paf, + match_peaks_paf_par, + instances_nms, +) from sleap.nn.util import batch, batch_count, save_visual_outputs -OVERLAPPING_INSTANCES_NMS = True + +@attr.s(auto_attribs=True) +class InferenceModel: + """This class provides convenience metadata and methods for running inference from a TrainingJob.""" + + job: TrainingJob + _keras_model: keras.Model = None + _model_path: Text = None + _trained_input_shape: Tuple[int] = None + _output_channels: int = None + + @property + def skeleton(self) -> Skeleton: + """Returns the skeleton associated with this model.""" + + return self.job.model.skeletons[0] + + @property + def output_type(self) -> ModelOutputType: + """Returns the output type of this model.""" + + return self.job.model.output_type + + @property + def input_scale(self) -> float: + """Returns the scale of the images that the model was trained on.""" + + return self.job.trainer.scale + + @property + def output_scale(self) -> float: + """Returns the scale of the outputs of the model relative to the original data. + + For a model trained on inputs with scale = 0.5 that outputs predictions that + are half of the size of the inputs, the output scale is 0.25. + """ + return self.input_scale * self.job.model.output_scale + + @property + def output_relative_scale(self) -> float: + """Returns the scale of the outputs relative to the scaled inputs. + + This differs from output_scale in that it is the scaling factor after + applying the input scaling. + """ + + return self.job.model.output_scale + + def compute_output_shape( + self, input_shape: Tuple[int], relative=True + ) -> Tuple[int]: + """Returns the output tensor shape for a given input shape. + + Args: + input_shape: Shape of input images in the form (height, width). + relative: If True, input_shape specifies the shape after input scaling. + + Returns: + A tuple of (height, width, channels) of the output of the model. + """ + + # TODO: Support multi-input/multi-output models. + + scaling_factor = self.output_scale + if relative: + scaling_factor = self.output_relative_scale + + output_shape = ( + int(input_shape[0] * scaling_factor), + int(input_shape[1] * scaling_factor), + self.output_channels, + ) + + return output_shape + + def load_model(self, model_path: Text = None) -> keras.Model: + """Loads a saved model from disk and caches it. + + Args: + model_path: If not provided, uses the model + paths in the training job. + + Returns: + The loaded Keras model. This model can accept any size + of inputs that are valid. + """ + + if not model_path: + # Try the best model first. + model_path = os.path.join(self.job.save_dir, self.job.best_model_filename) + + # Try the final model if that didn't exist. + if not os.path.exists(model_path): + model_path = os.path.join( + self.job.save_dir, self.job.final_model_filename + ) + + # Load from disk. + keras_model = keras.models.load_model(model_path, custom_objects={"tf": tf}) + logger.info("Loaded model: " + model_path) + + # Store the loaded model path for reference. + self._model_path = model_path + + # TODO: Multi-input/output support + # Find the original data shape from the input shape of the first input node. + self._trained_input_shape = keras_model.get_input_shape_at(0) + + # Save output channels since that should be static. + self._output_channels = keras_model.get_output_shape_at(0)[-1] + + # Create input node with undetermined height/width. + input_tensor = keras.layers.Input((None, None, self.input_channels)) + keras_model = keras.Model( + inputs=input_tensor, outputs=keras_model(input_tensor) + ) + + # Save the modified and loaded model. + self._keras_model = keras_model + + return self.keras_model + + @property + def keras_model(self) -> keras.Model: + """Returns the underlying Keras model, loading it if necessary.""" + + if self._keras_model is None: + self.load_model() + + return self._keras_model + + @property + def model_path(self) -> Text: + """Returns the path to the loaded model.""" + + if not self._model_path: + raise AttributeError( + "No model loaded. Call inference_model.load_model() first." + ) + + return self._model_path + + @property + def trained_input_shape(self) -> Tuple[int]: + """Returns the shape of the model when it was loaded.""" + + if not self._trained_input_shape: + raise AttributeError( + "No model loaded. Call inference_model.load_model() first." + ) + + return self._trained_input_shape + + @property + def output_channels(self) -> int: + """Returns the number of output channels of the model.""" + if not self._trained_input_shape: + raise AttributeError( + "No model loaded. Call inference_model.load_model() first." + ) + + return self._output_channels + + @property + def input_channels(self) -> int: + """Returns the number of channels expected for the input data.""" + + # TODO: Multi-output support + return self.trained_input_shape[-1] + + @property + def is_grayscale(self) -> bool: + """Returns True if the model expects grayscale images.""" + + return self.input_channels == 1 + + @property + def down_blocks(self): + """Returns the number of pooling steps applied during the model. + + Data needs to be of a shape divisible by the number of pooling steps. + """ + + # TODO: Replace this with an explicit calculation that takes stride sizes into account. + return self.job.model.down_blocks + + def predict( + self, + X: Union[np.ndarray, List[np.ndarray]], + batch_size: int = 32, + normalize: bool = True, + ) -> Union[np.ndarray, List[np.ndarray]]: + """Runs inference on the input data. + + This is a simple wrapper around the keras model predict function. + + Args: + X: The inputs to provide to the model. Can be different height/width as + the data it was trained on. + batch_size: Batch size to perform inference on at a time. + normalize: Applies normalization to the input data if needed + (e.g., if casting or range normalization is required). + + Returns: + The outputs of the model. + """ + + if normalize: + # TODO: Store normalization scheme in the model metadata. + if isinstance(X, np.ndarray): + if X.dtype == np.dtype("uint8"): + X = X.astype("float32") / 255.0 + elif isinstance(X, list): + for i in range(len(X)): + if X[i].dtype == np.dtype("uint8"): + X[i] = X[i].astype("float32") / 255.0 + + return self.keras_model.predict(X, batch_size=batch_size) + @attr.s(auto_attribs=True) class Predictor: @@ -67,9 +288,11 @@ class Predictor: nms_min_thresh: A threshold of non-max suppression peak finding in confidence maps. All values below this minimum threshold will be set to zero before peak finding algorithm is run. - nms_sigma: Gaussian blur is applied to confidence maps before - non-max supression peak finding occurs. This is the - standard deviation of the kernel applied to the image. + nms_kernel_size: Gaussian blur is applied to confidence maps before + non-max supression peak finding occurs. This is size of the + kernel applied to the image. + nms_sigma: For Gassian blur applied to confidence maps, this + is the standard deviation of the kernel. min_score_to_node_ratio: FIXME min_score_midpts: FIXME min_score_integral: FIXME @@ -77,27 +300,30 @@ class Predictor: with_tracking: whether to run tracking after inference flow_window: The number of frames that tracking should look back when trying to identify instances. - crop_iou_threshold: FIXME single_per_crop: FIXME output_path: the output path to save the results save_confmaps_pafs: whether to save confmaps/pafs resize_hack: whether to resize images to power of 2 """ - sleap_models: Dict[ModelOutputType, TrainingJob] = None + training_jobs: Dict[ModelOutputType, TrainingJob] = None + inference_models: Dict[ModelOutputType, InferenceModel] = attr.ib( + default=attr.Factory(dict) + ) + skeleton: Skeleton = None inference_batch_size: int = 2 read_chunk_size: int = 256 - save_frequency: int = 100 # chunks + save_frequency: int = 100 # chunks nms_min_thresh = 0.3 - nms_sigma = 3 + nms_kernel_size: int = 9 + nms_sigma: float = 3.0 min_score_to_node_ratio: float = 0.2 min_score_midpts: float = 0.05 min_score_integral: float = 0.6 add_last_edge: bool = True with_tracking: bool = False flow_window: int = 15 - crop_iou_threshold: float = .9 single_per_crop: bool = False crop_padding: int = 40 crop_growth: int = 64 @@ -105,15 +331,27 @@ class Predictor: output_path: Optional[str] = None save_confmaps_pafs: bool = False resize_hack: bool = True + pool: multiprocessing.Pool = None - _models: Dict = attr.ib(default=attr.Factory(dict)) + gpu_peak_finding: bool = True + supersample_window_size: int = 7 # must be odd + supersample_factor: float = 2 # factor to upsample cropped windows by + overlapping_instances_nms: bool = True # suppress overlapping instances - def predict(self, - input_video: Union[dict, Video], - frames: Optional[List[int]] = None, - is_async: bool = False) -> List[LabeledFrame]: - """ - Run the entire inference pipeline on an input video. + def __attrs_post_init__(self): + + # Create inference models from the TrainingJob metadata. + for model_output_type, training_job in self.training_jobs.items(): + self.inference_models[model_output_type] = InferenceModel(job=training_job) + self.inference_models[model_output_type].load_model() + + def predict( + self, + input_video: Union[dict, Video], + frames: Optional[List[int]] = None, + is_async: bool = False, + ) -> List[LabeledFrame]: + """Run the entire inference pipeline on an input video. Args: input_video: Either a `Video` object or dict that can be @@ -125,13 +363,19 @@ def predict(self, children. Returns: - list of LabeledFrame objects + A list of LabeledFrames with predicted instances. """ + # Check if we have models. + if len(self.inference_models) == 0: + logger.warning("Predictor has no model.") + raise ValueError("Predictor has no model.") + self.is_async = is_async - # Initialize parallel pool - self.pool = None if self.is_async else multiprocessing.Pool(processes=usable_cpu_count()) + # Initialize parallel pool if needed. + if not is_async and self.pool is None: + self.pool = multiprocessing.Pool(processes=usable_cpu_count()) # Fix the number of threads for OpenCV, not that we are using # anything in OpenCV that is actually multi-threaded but maybe @@ -140,46 +384,45 @@ def predict(self, logger.info(f"Predict is async: {is_async}") - # Open the video if we need it. + # Find out if the images should be grayscale from the first model. + # TODO: Unify this with input data normalization. + grayscale = list(self.inference_models.values())[0].is_grayscale - try: - input_video.get_frame(0) + # Open the video object if needed. + if isinstance(input_video, Video): vid = input_video - except AttributeError: - if isinstance(input_video, dict): - vid = Video.cattr().structure(input_video, Video) - elif isinstance(input_video, str): - vid = Video.from_filename(input_video) - else: - raise AttributeError(f"Unable to load input video: {input_video}") + elif isinstance(input_video, dict): + vid = Video.cattr().structure(input_video, Video) + elif isinstance(input_video, str): + vid = Video.from_filename(input_video, grayscale=grayscale) + else: + raise AttributeError(f"Unable to load input video: {input_video}") # List of frames to process (or entire video if not specified) frames = frames or list(range(vid.num_frames)) - - vid_h = vid.shape[1] - vid_w = vid.shape[2] - logger.info("Opened video:") logger.info(" Source: " + str(vid.backend)) logger.info(" Frames: %d" % len(frames)) - logger.info(" Frame shape: %d x %d" % (vid_h, vid_w)) - - # Check training models - if len(self.sleap_models) == 0: - logger.warning("Predictor has no model.") - raise ValueError("Predictor has no model.") + logger.info(" Frame shape (H x W): %d x %d" % (vid.height, vid.width)) # Initialize tracking - tracker = FlowShiftTracker(window=self.flow_window, verbosity=0) + if self.with_tracking: + tracker = FlowShiftTracker(window=self.flow_window, verbosity=0) + + if self.output_path: + # Delete the output file if it exists already + if os.path.exists(self.output_path): + os.unlink(self.output_path) + logger.warning("Deleted existing output: " + self.output_path) - # Delete the output file if it exists already - if os.path.exists(self.output_path): - os.unlink(self.output_path) + # Create output directory if it doesn't exist + if not os.path.exists(self.output_path): + os.makedirs(os.path.dirname(self.output_path), exist_ok=True) + logger.info("Output path: " + self.output_path) # Process chunk-by-chunk! t0_start = time() - predicted_frames: List[LabeledFrame] = [] - + predicted_frames = [] num_chunks = batch_count(frames, self.read_chunk_size) logger.info("Number of chunks for process: %d" % (num_chunks)) @@ -204,44 +447,23 @@ def predict(self, # Read the next batch of images t0 = time() - mov_full = vid[frames_idx] - logger.info(" Read %d frames [%.1fs]" % (len(mov_full), time() - t0)) + imgs_full = vid[frames_idx] + logger.info(" Read %d frames [%.1fs]" % (len(imgs_full), time() - t0)) # Transform images (crop or scale) t0 = time() - if ModelOutputType.CENTROIDS in self.sleap_models: - # Use centroid predictions to get subchunks of crops + if ModelOutputType.CENTROIDS in self.inference_models: + # Use centroid predictions to get subchunks of crops. subchunks_to_process = self.centroid_crop_inference( - mov_full, frames_idx, - iou_threshold=self.crop_iou_threshold) + imgs_full, frames_idx + ) else: - # Scale without centroid cropping - - # Get the scale that was used when training models - model_data = get_model_data(self.sleap_models, [ModelOutputType.CONFIDENCE_MAP]) - scale = model_data["scale"] - - # Determine scaled image size - scale_to = (int(vid.height//(1/scale)), int(vid.width//(1/scale))) - - # FIXME: Adjust to appropriate power of 2 - # It would be better to pad image to a usable size, since - # the resize could affect aspect ratio. - if self.resize_hack: - scale_to = (scale_to[0]//8*8, scale_to[1]//8*8) - # Create transform object - transform = DataTransform( - frame_idxs = frames_idx, - scale = model_data["multiscale"]) - - # Scale if target doesn't match current size - mov = transform.scale_to(mov_full, target_size=scale_to) - - subchunks_to_process = [(mov, transform)] + transform = DataTransform(frame_idxs=frames_idx) + subchunks_to_process = [(imgs_full, transform)] logger.info(" Transformed images [%.1fs]" % (time() - t0)) @@ -259,32 +481,32 @@ def predict(self, subchunk_results = [] - for subchunk_mov, subchunk_transform in subchunks_to_process: + for subchunk_imgs_full, subchunk_transform in subchunks_to_process: logger.info(f" Running inference for subchunk:") - logger.info(f" Shape: {subchunk_mov.shape}") - logger.info(f" Prediction Scale: {subchunk_transform.scale}") + logger.info(f" Shape: {subchunk_imgs_full.shape}") + logger.info(f" Scale: {subchunk_transform.scale}") - if ModelOutputType.PART_AFFINITY_FIELD not in self.sleap_models: + if ModelOutputType.PART_AFFINITY_FIELD not in self.inference_models: # Pipeline for predicting a single animal in a frame # This uses only confidence maps logger.warning("No PAF model! Running in SINGLE INSTANCE mode.") subchunk_lfs = self.single_instance_inference( - subchunk_mov, - subchunk_transform, - vid) + subchunk_imgs_full, subchunk_transform, vid + ) else: # Pipeline for predicting multiple animals in a frame # This uses confidence maps and part affinity fields subchunk_lfs = self.multi_instance_inference( - subchunk_mov, - subchunk_transform, - vid) + subchunk_imgs_full, subchunk_transform, vid + ) - logger.info(f" Subchunk frames with instances found: {len(subchunk_lfs)}") + logger.info( + f" Subchunk frames with instances found: {len(subchunk_lfs)}" + ) subchunk_results.append(subchunk_lfs) @@ -306,9 +528,13 @@ def predict(self, predicted_frames_chunk = [] for subchunk_frames in subchunk_results: predicted_frames_chunk.extend(subchunk_frames) - predicted_frames_chunk = LabeledFrame.merge_frames(predicted_frames_chunk, video=vid) + predicted_frames_chunk = LabeledFrame.merge_frames( + predicted_frames_chunk, video=vid + ) - logger.info(f" Instances found on {len(predicted_frames_chunk)} out of {len(mov_full)} frames.") + logger.info( + f" Instances found on {len(predicted_frames_chunk)} out of {len(imgs_full)} frames." + ) if len(predicted_frames_chunk): @@ -318,7 +544,7 @@ def predict(self, # Track if self.with_tracking and len(predicted_frames_chunk): t0 = time() - tracker.process(mov_full, predicted_frames_chunk) + tracker.process(imgs_full, predicted_frames_chunk) logger.info(" Tracked IDs via flow shift [%.1fs]" % (time() - t0)) # Save @@ -327,23 +553,30 @@ def predict(self, if chunk % self.save_frequency == 0 or chunk == (num_chunks - 1): t0 = time() - # FIXME: We are re-writing the whole output each time, this is dumb. + # TODO: We are re-writing the whole output each time, this is dumb. # We should save in chunks then combine at the end. labels = Labels(labeled_frames=predicted_frames) if self.output_path is not None: - if self.output_path.endswith('json'): - Labels.save_json(labels, filename=self.output_path, compress=True) + if self.output_path.endswith("json"): + Labels.save_json( + labels, filename=self.output_path, compress=True + ) else: Labels.save_hdf5(labels, filename=self.output_path) - logger.info(" Saved to: %s [%.1fs]" % (self.output_path, time() - t0)) + logger.info( + " Saved to: %s [%.1fs]" % (self.output_path, time() - t0) + ) elapsed = time() - t0_chunk total_elapsed = time() - t0_start fps = len(predicted_frames) / total_elapsed frames_left = len(frames) - len(predicted_frames) eta = (frames_left / fps) if fps > 0 else 0 - logger.info(" Finished chunk [%.1fs / %.1f FPS / ETA: %.1f min]" % (elapsed, fps, eta / 60)) + logger.info( + " Finished chunk [%.1fs / %.1f FPS / ETA: %.1f min]" + % (elapsed, fps, eta / 60) + ) sys.stdout.flush() @@ -380,21 +613,22 @@ def predict_async(self, *args, **kwargs) -> Tuple[Pool, AsyncResult]: # unstructure input_video since it won't pickle kwargs["input_video"] = Video.cattr().unstructure(kwargs["input_video"]) - pool = Pool(processes=1) - result = pool.apply_async(self.predict, args=args, kwds=kwargs) + if self.pool is None: + self.pool = Pool(processes=1) + result = self.pool.apply_async(self.predict, args=args, kwds=kwargs) # Tell the pool to accept no new tasks - pool.close() - - return pool, result + # pool.close() - # Methods for running inferring on components of pipeline + return result - def centroid_crop_inference(self, - imgs: np.ndarray, - frames_idx: List[int], - iou_threshold: float=.9) \ - -> List[Tuple[np.ndarray, DataTransform]]: + def centroid_crop_inference( + self, + imgs: np.ndarray, + frames_idx: List[int], + box_size: int = None, + do_merge: bool = True, + ) -> List[Tuple[np.ndarray, DataTransform]]: """ Takes stack of images and runs centroid inference to get crops. @@ -408,52 +642,66 @@ def centroid_crop_inference(self, which allows us to merge overlapping crops into larger crops. """ - crop_within = (imgs.shape[1]//8*8, imgs.shape[2]//8*8) + # Get inference models with metadata. + centroid_model = self.inference_models[ModelOutputType.CENTROIDS] + cm_model = self.inference_models[ModelOutputType.CONFIDENCE_MAP] - # Fetch centroid model (uses cache if already loaded) + logger.info(" Performing centroid cropping.") - model_package = self.fetch_model( - input_size = None, - output_types = [ModelOutputType.CENTROIDS]) + # TODO: Replace this calculation when model-specific divisibility calculation implemented. + divisor = 2 ** centroid_model.down_blocks + crop_within = ( + (imgs.shape[1] // divisor) * divisor, + (imgs.shape[2] // divisor) * divisor, + ) + logger.info(f" crop_within: {crop_within}") # Create transform - # This lets us scale the images before we predict centroids, # and will also let us map the points on the scaled image to # points on the original images so we can crop original images. - centroid_transform = DataTransform() + target_shape = ( + int(imgs.shape[1] * centroid_model.input_scale), + int(imgs.shape[2] * centroid_model.input_scale), + ) - # Scale to match input size of trained centroid model - # Usually this will be 1/4-scale of original images + # Scale to match input size of trained centroid model. + centroid_imgs_scaled = centroid_transform.scale_to( + imgs=imgs, target_size=target_shape + ) - centroid_imgs_scaled = \ - centroid_transform.scale_to( - imgs=imgs, - target_size=model_package["model"].input_shape[1:3]) - - # Predict centroid confidence maps, then find peaks - - centroid_confmaps = model_package["model"].predict(centroid_imgs_scaled.astype("float32") / 255, - batch_size=self.inference_batch_size) - - peaks, peak_vals = find_all_peaks(centroid_confmaps, - min_thresh=self.nms_min_thresh, - sigma=self.nms_sigma) - - # Get training bounding box size to determine (min) centroid crop size - - crop_model_package = self.fetch_model( - input_size = None, - output_types = [ModelOutputType.CONFIDENCE_MAP]) - crop_size = crop_model_package["bounding_box_size"] - bb_half = (crop_size + self.crop_padding)//2 + # Predict centroid confidence maps, then find peaks. + t0 = time() + centroid_confmaps = centroid_model.predict( + centroid_imgs_scaled, batch_size=self.inference_batch_size + ) + + peaks, peak_vals = find_all_peaks( + centroid_confmaps, min_thresh=self.nms_min_thresh, sigma=self.nms_sigma + ) + + elapsed = time() - t0 + total_peaks = sum([len(frame_peaks[0]) for frame_peaks in peaks]) + logger.info( + f" Found {total_peaks} centroid peaks ({total_peaks / len(peaks):.2f} centroids/frame) [{elapsed:.2f}s]." + ) + + if box_size is None: + # Get training bounding box size to determine (min) centroid crop size. + # TODO: fix this to use a stored value or move this logic elsewhere + crop_size = int( + max(cm_model.trained_input_shape[1:3]) // cm_model.input_scale + ) + bb_half = crop_size // 2 + # bb_half = (crop_size + self.crop_padding) // 2 + else: + bb_half = box_size // 2 - logger.info(f" Centroid crop box size: {bb_half*2}") - - all_boxes = dict() + logger.info(f" Crop box size: {bb_half * 2}") # Iterate over each frame to filter bounding boxes + all_boxes = dict() for frame_i, (frame_peaks, frame_peak_vals) in enumerate(zip(peaks, peak_vals)): # If we found centroids on this frame... @@ -465,45 +713,68 @@ def centroid_crop_inference(self, boxes = [] for peak_i in range(frame_peaks[0].shape[0]): + # Rescale peak back onto full-sized image - peak_x = int(frame_peaks[0][peak_i][0] / centroid_transform.scale) - peak_y = int(frame_peaks[0][peak_i][1] / centroid_transform.scale) + peak_x = int( + frame_peaks[0][peak_i][0] / centroid_model.output_scale + ) + peak_y = int( + frame_peaks[0][peak_i][1] / centroid_model.output_scale + ) - boxes.append((peak_x-bb_half, peak_y-bb_half, - peak_x+bb_half, peak_y+bb_half)) + boxes.append( + ( + peak_x - bb_half, + peak_y - bb_half, + peak_x + bb_half, + peak_y + bb_half, + ) + ) - # Merge overlapping boxes and pad to multiple of crop size - merged_boxes = merge_boxes_with_overlap_and_padding( - boxes=boxes, - pad_factor_box=(self.crop_growth, self.crop_growth), - within=crop_within) + if do_merge: + # Merge overlapping boxes and pad to multiple of crop size + merged_boxes = merge_boxes_with_overlap_and_padding( + boxes=boxes, + pad_factor_box=(self.crop_growth, self.crop_growth), + within=crop_within, + ) + + else: + # Just return the boxes centered around each centroid. + # Note that these aren't guaranteed to be within the + # image bounds, so take care if using these to crop. + merged_boxes = boxes # Keep track of all boxes, grouped by size and frame idx for box in merged_boxes: - box_size = (box[2]-box[0], box[3]-box[1]) + merged_box_size = (box[2] - box[0], box[3] - box[1]) + + if merged_box_size not in all_boxes: + all_boxes[merged_box_size] = dict() + logger.info(f" Found box size: {merged_box_size}") + + if frame_i not in all_boxes[merged_box_size]: + all_boxes[merged_box_size][frame_i] = [] - if box_size not in all_boxes: - all_boxes[box_size] = dict() - if frame_i not in all_boxes[box_size]: - all_boxes[box_size][frame_i] = [] + all_boxes[merged_box_size][frame_i].append(box) - all_boxes[box_size][frame_i].append(box) + logger.info(f" Found {len(all_boxes)} box sizes after merging.") subchunks = [] # Check if we found any boxes for this chunk of frames if len(all_boxes): - model_data = get_model_data(self.sleap_models, [ModelOutputType.CONFIDENCE_MAP]) # We'll make a "subchunk" for each crop size for crop_size in all_boxes: - if crop_size[0] >= 1024: - logger.info(f" Skipping subchunk for size {crop_size}, would have {len(all_boxes[crop_size])} crops.") - for debug_frame_idx in all_boxes[crop_size].keys(): - print(f" frame {frames_idx[debug_frame_idx]}: {all_boxes[crop_size][debug_frame_idx]}") - continue + # TODO: Look into this edge case? + # if crop_size[0] >= 1024: + # logger.info(f" Skipping subchunk for size {crop_size}, would have {len(all_boxes[crop_size])} crops.") + # for debug_frame_idx in all_boxes[crop_size].keys(): + # print(f" frame {frames_idx[debug_frame_idx]}: {all_boxes[crop_size][debug_frame_idx]}") + # continue # Make list of all boxes and corresponding img index. subchunk_idxs = [] @@ -511,12 +782,12 @@ def centroid_crop_inference(self, for frame_i, frame_boxes in all_boxes[crop_size].items(): subchunk_boxes.extend(frame_boxes) - subchunk_idxs.extend( [frame_i] * len(frame_boxes) ) + subchunk_idxs.extend([frame_i] * len(frame_boxes)) + # TODO: This should probably be in the main loop # Create transform object - transform = DataTransform( - frame_idxs = frames_idx, - scale = model_data["multiscale"]) + # transform = DataTransform(frame_idxs=frames_idx, scale=cm_model.output_relative_scale) + transform = DataTransform(frame_idxs=frames_idx) # Do the cropping imgs_cropped = transform.crop(imgs, subchunk_boxes, subchunk_idxs) @@ -524,7 +795,9 @@ def centroid_crop_inference(self, # Add subchunk subchunks.append((imgs_cropped, transform)) - logger.info(f" Subchunk for size {crop_size} has {len(imgs_cropped)} crops.") + logger.info( + f" Subchunk for size {crop_size} has {len(imgs_cropped)} crops." + ) else: logger.info(" No centroids found so done with this chunk.") @@ -532,208 +805,319 @@ def centroid_crop_inference(self, return subchunks def single_instance_inference(self, imgs, transform, video) -> List[LabeledFrame]: - """Run the single instance pipeline for a stack of images.""" + """Run the single instance pipeline for a stack of images. - # Get confmap model for this image size - model_package = self.fetch_model( - input_size = imgs.shape[1:], - output_types = [ModelOutputType.CONFIDENCE_MAP]) + Args: + imgs: Subchunk of images to process. + transform: DataTransform object tracking input transformations. + video: Video object for building LabeledFrames with correct reference to source. - # Run inference - t0 = time() + Returns: + A list of LabeledFrames with predicted points. + """ + + # Get confmap inference model. + cm_model = self.inference_models[ModelOutputType.CONFIDENCE_MAP] - confmaps = model_package["model"].predict(imgs.astype("float32") / 255, batch_size=self.inference_batch_size) - logger.info( " Inferred confmaps [%.1fs]" % (time() - t0)) + # Scale to match input size of trained model. + # Images are expected to be at full resolution, but may be cropped. + assert transform.scale == 1.0 + target_shape = ( + int(imgs.shape[1] * cm_model.input_scale), + int(imgs.shape[2] * cm_model.input_scale), + ) + imgs_scaled = transform.scale_to(imgs=imgs, target_size=target_shape) + + # TODO: Adjust for divisibility + # divisor = 2 ** cm_model.down_blocks + # crop_within = ((imgs.shape[1] // divisor) * divisor, (imgs.shape[2] // divisor) * divisor) + + # Run inference. + t0 = time() + confmaps = cm_model.predict(imgs_scaled, batch_size=self.inference_batch_size) + logger.info(" Inferred confmaps [%.1fs]" % (time() - t0)) logger.info(f" confmaps: shape={confmaps.shape}, ptp={np.ptp(confmaps)}") t0 = time() + # TODO: Move this to GPU and add subpixel refinement. # Use single highest peak in channel corresponding node - points_arrays = find_all_single_peaks(confmaps, - min_thresh=self.nms_min_thresh) + points_arrays = find_all_single_peaks(confmaps, min_thresh=self.nms_min_thresh) + # Adjust for multi-scale such that the points are at the scale of the transform. + points_arrays = [pts / cm_model.output_relative_scale for pts in points_arrays] + + # Create labeled frames and predicted instances from the points. predicted_frames_chunk = match_single_peaks_all( - points_arrays = points_arrays, - skeleton = model_package["skeleton"], - transform = transform, - video = video) + points_arrays=points_arrays, + skeleton=cm_model.skeleton, + transform=transform, + video=video, + ) logger.info(" Used highest peaks to create instances [%.1fs]" % (time() - t0)) # Save confmaps if self.output_path is not None and self.save_confmaps_pafs: - save_visual_outputs( - output_path = self.output_path, - data = dict(confmaps=confmaps, box=imgs)) + raise NotImplementedError( + "Not saving confmaps/pafs because feature currently not working." + ) + # Disable save_confmaps_pafs since not currently working. + # The problem is that we can't put data for different crop sizes + # all into a single h5 datasource. It's now possible to view live + # predicted confmap and paf in the gui, so this isn't high priority. + # save_visual_outputs( + # output_path = self.output_path, + # data = dict(confmaps=confmaps, box=imgs)) return predicted_frames_chunk def multi_instance_inference(self, imgs, transform, video) -> List[LabeledFrame]: - """ - Run the multi-instance inference pipeline for a stack of images. + """Run the multi-instance inference pipeline for a stack of images. + + Args: + imgs: Subchunk of images to process. + transform: DataTransform object tracking input transformations. + video: Video object for building LabeledFrames with correct reference to source. + + Returns: + A list of LabeledFrames with predicted points. """ # Load appropriate models as needed - conf_model = self.fetch_model( - input_size = imgs.shape[1:], - output_types = [ModelOutputType.CONFIDENCE_MAP]) - - paf_model = self.fetch_model( - input_size = imgs.shape[1:], - output_types = [ModelOutputType.PART_AFFINITY_FIELD]) + cm_model = self.inference_models[ModelOutputType.CONFIDENCE_MAP] + paf_model = self.inference_models[ModelOutputType.PART_AFFINITY_FIELD] # Find peaks t0 = time() - peaks, peak_vals, confmaps = \ - peak_tf_inference( - model = conf_model["model"], - data = imgs.astype("float32")/255, - min_thresh=self.nms_min_thresh, - downsample_factor=int(1/paf_model["multiscale"]), - upsample_factor=int(1/conf_model["multiscale"]), - return_confmaps=self.save_confmaps_pafs - ) + # Scale to match input resolution of model. + # Images are expected to be at full resolution, but may be cropped. + assert transform.scale == 1.0 + cm_target_shape = ( + int(imgs.shape[1] * cm_model.input_scale), + int(imgs.shape[2] * cm_model.input_scale), + ) + imgs_scaled = transform.scale_to(imgs=imgs, target_size=cm_target_shape) + if imgs_scaled.dtype == np.dtype("uint8"): # TODO: Unify normalization. + imgs_scaled = imgs_scaled.astype("float32") / 255.0 + + # TODO: Unfuck this whole workflow + if self.gpu_peak_finding: + confmaps_shape = cm_model.compute_output_shape( + (imgs_scaled.shape[1], imgs_scaled.shape[2]) + ) + peaks, peak_vals, confmaps = peak_tf_inference( + model=cm_model.keras_model, + confmaps_shape=confmaps_shape, + data=imgs_scaled, + min_thresh=self.nms_min_thresh, + gaussian_size=self.nms_kernel_size, + gaussian_sigma=self.nms_sigma, + upsample_factor=int(self.supersample_factor / cm_model.output_scale), + win_size=self.supersample_window_size, + return_confmaps=self.save_confmaps_pafs, + batch_size=self.inference_batch_size, + ) - transform.scale = transform.scale * paf_model["multiscale"] + else: + confmaps = cm_model.predict( + imgs_scaled, batch_size=self.inference_batch_size + ) + peaks, peak_vals = find_all_peaks( + confmaps, min_thresh=self.nms_min_thresh, sigma=self.nms_sigma + ) + + # # Undo just the scaling so we're back to full resolution, but possibly cropped. + for t in range(len(peaks)): # frames + for c in range(len(peaks[t])): # channels + peaks[t][c] /= cm_model.output_scale + + # Peaks should be at (refined) full resolution now. + # Keep track of scale adjustment. + transform.scale = 1.0 + + elapsed = time() - t0 + total_peaks = sum( + [ + len(channel_peaks) + for frame_peaks in peaks + for channel_peaks in frame_peaks + ] + ) + logger.info( + f" Found {total_peaks} peaks ({total_peaks / len(imgs):.2f} peaks/frame) [{elapsed:.2f}s]." + ) + # logger.info(f" peaks: {peaks}") + + # Scale to match input resolution of model. + # Images are expected to be at full resolution, but may be cropped. + paf_target_shape = ( + int(imgs.shape[1] * paf_model.input_scale), + int(imgs.shape[2] * paf_model.input_scale), + ) + if (imgs_scaled.shape[1] == paf_target_shape[0]) and ( + imgs_scaled.shape[2] == paf_target_shape[1] + ): + # No need to scale again if we're already there, so just adjust the stored scale + transform.scale = paf_model.input_scale - logger.info(" Inferred confmaps and found-peaks (gpu) [%.1fs]" % (time() - t0)) - logger.info(f" peaks: {len(peaks)}") + else: + # Adjust scale from full resolution images (avoiding possible resizing up from confmaps input scale) + imgs_scaled = transform.scale_to(imgs=imgs, target_size=paf_target_shape) # Infer pafs t0 = time() - pafs = paf_model["model"].predict(imgs.astype("float32") / 255, batch_size=self.inference_batch_size) - - logger.info( " Inferred PAFs [%.1fs]" % (time() - t0)) + pafs = paf_model.predict(imgs_scaled, batch_size=self.inference_batch_size) + logger.info(" Inferred PAFs [%.1fs]" % (time() - t0)) logger.info(f" pafs: shape={pafs.shape}, ptp={np.ptp(pafs)}") + # Adjust points to the paf output scale so we can invert later (should not incur loss of precision) + # TODO: Check precision + for t in range(len(peaks)): # frames + for c in range(len(peaks[t])): # channels + peaks[t][c] *= paf_model.output_scale + transform.scale = paf_model.output_scale + # Determine whether to use serial or parallel version of peak-finding # Use the serial version is we're already running in a thread pool - match_peaks_function = match_peaks_paf_par if not self.is_async else match_peaks_paf + match_peaks_function = ( + match_peaks_paf_par if not self.is_async else match_peaks_paf + ) # Match peaks via PAFs t0 = time() - predicted_frames_chunk = match_peaks_function( - peaks, peak_vals, pafs, conf_model["skeleton"], - transform=transform, video=video, - min_score_to_node_ratio=self.min_score_to_node_ratio, - min_score_midpts=self.min_score_midpts, - min_score_integral=self.min_score_integral, - add_last_edge=self.add_last_edge, - single_per_crop=self.single_per_crop, - pool=self.pool) - + peaks, + peak_vals, + pafs, + paf_model.skeleton, + transform=transform, + video=video, + min_score_to_node_ratio=self.min_score_to_node_ratio, + min_score_midpts=self.min_score_midpts, + min_score_integral=self.min_score_integral, + add_last_edge=self.add_last_edge, + single_per_crop=self.single_per_crop, + pool=self.pool, + ) + + total_instances = sum( + [len(labeled_frame) for labeled_frame in predicted_frames_chunk] + ) logger.info(" Matched peaks via PAFs [%.1fs]" % (time() - t0)) + logger.info( + f" Found {total_instances} instances ({total_instances / len(imgs):.2f} instances/frame)" + ) # Remove overlapping predicted instances - if OVERLAPPING_INSTANCES_NMS: + if self.overlapping_instances_nms: t0 = clock() for lf in predicted_frames_chunk: n = len(lf.instances) instances_nms(lf.instances) if len(lf.instances) < n: - logger.info(f" Removed {n-len(lf.instances)} overlapping instance(s) from frame {lf.frame_idx}") + logger.info( + f" Removed {n-len(lf.instances)} overlapping instance(s) from frame {lf.frame_idx}" + ) logger.info(" Instance NMS [%.1fs]" % (clock() - t0)) # Save confmaps and pafs if self.output_path is not None and self.save_confmaps_pafs: - save_visual_outputs( - output_path = self.output_path, - data = dict(confmaps=confmaps, pafs=pafs, - frame_idxs=transform.frame_idxs, bounds=transform.bounding_boxes)) + raise NotImplementedError( + "Not saving confmaps/pafs because feature currently not working." + ) + # Disable save_confmaps_pafs since not currently working. + # The problem is that we can't put data for different crop sizes + # all into a single h5 datasource. It's now possible to view live + # predicted confmap and paf in the gui, so this isn't high priority. + # save_visual_outputs( + # output_path = self.output_path, + # data = dict(confmaps=confmaps, pafs=pafs, + # frame_idxs=transform.frame_idxs, bounds=transform.bounding_boxes)) return predicted_frames_chunk - def fetch_model(self, - input_size: tuple, - output_types: List[ModelOutputType]) -> dict: - """Loads and returns keras Model with caching.""" - - key = (input_size, tuple(output_types)) - - if key not in self._models: - - # Load model - - keras_model = load_model(self.sleap_models, input_size, output_types) - first_sleap_model = self.sleap_models[output_types[0]] - model_data = get_model_data(self.sleap_models, output_types) - skeleton = get_model_skeleton(self.sleap_models, output_types) - - # logger.info(f"Model multiscale: {model_data['multiscale']}") - - # If no input size was specified, then use the input size - # from original trained model. - - if input_size is None: - input_size = keras_model.input_shape[1:] - - # Get the size of the bounding box from training data - # (or the size of crop that model was trained on if the - # bounding box size wasn't set). - - if first_sleap_model.trainer.instance_crop: - bounding_box_size = \ - first_sleap_model.trainer.bounding_box_size or keras_model.input_shape[1] - else: - bounding_box_size = None - - # Cache the model so we don't have to load it next time - - self._models[key] = dict( - model=keras_model, - skeleton=model_data["skeleton"], - multiscale=model_data["multiscale"], - bounding_box_size=bounding_box_size - ) - - # Return the keras Model - return self._models[key] - def main(): - def frame_list(frame_str: str): # Handle ranges of frames. Must be of the form "1-200" - if '-' in frame_str: - min_max = frame_str.split('-') + if "-" in frame_str: + min_max = frame_str.split("-") min_frame = int(min_max[0]) max_frame = int(min_max[1]) - return list(range(min_frame, max_frame+1)) + return list(range(min_frame, max_frame + 1)) return [int(x) for x in frame_str.split(",")] if len(frame_str) else None parser = argparse.ArgumentParser() parser.add_argument("data_path", help="Path to video file") - parser.add_argument("-m", "--model", dest='models', action='append', - help="Path to saved model (confmaps, pafs, ...) JSON. " - "Multiple models can be specified, each preceded by " - "--model. Confmap and PAF models are required.", - required=True) - parser.add_argument('--resize-input', dest='resize_input', action='store_const', - const=True, default=False, - help='resize the input layer to image size (default False)') - parser.add_argument('--with-tracking', dest='with_tracking', action='store_const', - const=True, default=False, - help='just visualize predicted confmaps/pafs (default False)') - parser.add_argument('--frames', type=frame_list, default="", - help='list of frames to predict. Either comma separated list (e.g. 1,2,3) or ' - 'a range separated by hyphen (e.g. 1-3). (default is entire video)') - parser.add_argument('-o', '--output', type=str, default=None, - help='The output filename to use for the predicted data.') - parser.add_argument('--out_format', choices=['hdf5', 'json'], help='The format to use for' - ' the output file. Either hdf5 or json. hdf5 is the default.', - default='hdf5') - parser.add_argument('--save-confmaps-pafs', dest='save_confmaps_pafs', action='store_const', - const=True, default=False, - help='Whether to save the confidence maps or pafs') - parser.add_argument('-v', '--verbose', help='Increase logging output verbosity.', action="store_true") + parser.add_argument( + "-m", + "--model", + dest="models", + action="append", + help="Path to saved model (confmaps, pafs, ...) JSON. " + "Multiple models can be specified, each preceded by " + "--model. Confmap and PAF models are required.", + required=True, + ) + parser.add_argument( + "--resize-input", + dest="resize_input", + action="store_const", + const=True, + default=False, + help="resize the input layer to image size (default False)", + ) + parser.add_argument( + "--with-tracking", + dest="with_tracking", + action="store_const", + const=True, + default=False, + help="just visualize predicted confmaps/pafs (default False)", + ) + parser.add_argument( + "--frames", + type=frame_list, + default="", + help="list of frames to predict. Either comma separated list (e.g. 1,2,3) or " + "a range separated by hyphen (e.g. 1-3). (default is entire video)", + ) + parser.add_argument( + "-o", + "--output", + type=str, + default=None, + help="The output filename to use for the predicted data.", + ) + parser.add_argument( + "--out_format", + choices=["hdf5", "json"], + help="The format to use for" + " the output file. Either hdf5 or json. hdf5 is the default.", + default="hdf5", + ) + parser.add_argument( + "--save-confmaps-pafs", + dest="save_confmaps_pafs", + action="store_const", + const=True, + default=False, + help="Whether to save the confidence maps or pafs", + ) + parser.add_argument( + "-v", + "--verbose", + help="Increase logging output verbosity.", + action="store_true", + ) args = parser.parse_args() - if args.out_format == 'json': + if args.out_format == "json": output_suffix = ".predictions.json" else: output_suffix = ".predictions.h5" @@ -766,14 +1150,16 @@ def frame_list(frame_str: str): img_shape = None # Create a predictor to do the work. - predictor = Predictor(sleap_models=sleap_models, - output_path=save_path, - save_confmaps_pafs=args.save_confmaps_pafs, - with_tracking=args.with_tracking) + predictor = Predictor( + training_jobs=sleap_models, + output_path=save_path, + save_confmaps_pafs=args.save_confmaps_pafs, + with_tracking=args.with_tracking, + ) # Run the inference pipeline return predictor.predict(input_video=data_path, frames=frames) if __name__ == "__main__": - main() + main() diff --git a/sleap/nn/loadmodel.py b/sleap/nn/loadmodel.py deleted file mode 100644 index d65e6efc9..000000000 --- a/sleap/nn/loadmodel.py +++ /dev/null @@ -1,158 +0,0 @@ -import logging -logger = logging.getLogger(__name__) - -import numpy as np -import keras - -from time import time, clock -from typing import Dict, List, Union, Optional, Tuple - -import tensorflow as tf -import keras -# keras = tf.keras - -from sleap.skeleton import Skeleton -from sleap.nn.model import ModelOutputType -from sleap.nn.training import TrainingJob - -def load_model( - sleap_models: List[TrainingJob], - input_size: Optional[tuple], - output_types: List[ModelOutputType]) -> keras.Model: - """ - Load keras Model for specified input size and output types. - - Supports centroids, confmaps, and pafs. If output type includes - confmaps and pafs then we'll combine these into a single model. - - Arguments: - sleap_models: dict of the TrainingJobs where we can find models. - input_size: (h, w, c) tuple; if None, don't resize input layer - output_types: list of ModelOutputTypes - Returns: - keras Model - """ - - if ModelOutputType.CENTROIDS in output_types: - # Load centroid model - keras_model = load_model_from_job(sleap_models[ModelOutputType.CENTROIDS]) - - logger.info(f"Loaded centroid model trained on shape {keras_model.input_shape}") - - else: - # Load model for confmaps or pafs or both - - models = [] - - new_input_layer = tf.keras.layers.Input(input_size) if input_size is not None else None - - for output_type in output_types: - - # Load the model - job = sleap_models[output_type] - model = load_model_from_job(job) - - logger.info(f"Loaded {output_type} model trained on shape {model.input_shape}") - - # Get input layer if we didn't create one for a specified size - if new_input_layer is None: - new_input_layer = model.input - - # Resize input layer - model.layers.pop(0) - model = model(new_input_layer) - - logger.info(f" Resized input layer to {input_size}") - - # Add to list of models we've just loaded - models.append(model) - - if len(models) == 1: - keras_model = tf.keras.Model(new_input_layer, models[0]) - else: - # Merge multiple models into single model - keras_model = tf.keras.Model(new_input_layer, models) - - logger.info(f" Merged {len(models)} into single model") - - # keras_model = convert_to_gpu_model(keras_model) - - return keras_model - -def get_model_data( - sleap_models: Dict[ModelOutputType,TrainingJob], - output_types: List[ModelOutputType]) -> Dict: - - model_type = output_types[0] - job = sleap_models[model_type] - - # Model input is scaled by to get output - try: - asym = job.model.backbone.down_blocks - job.model.backbone.up_blocks - multiscale = 1/(2**asym) - except: - multiscale = 1 - - model_properties = dict( - skeleton = job.model.skeletons[0], - scale = job.trainer.scale, - multiscale = multiscale) - - return model_properties - -def get_model_skeleton(sleap_models, output_types) -> Skeleton: - - skeleton = get_model_data(sleap_models, output_types)["skeleton"] - - if skeleton is None: - logger.warning("Predictor has no skeleton.") - raise ValueError("Predictor has no skeleton.") - - return skeleton - -def load_model_from_job(job: TrainingJob) -> keras.Model: - """Load keras Model from a specific TrainingJob.""" - - # init = tf.global_variables_initializer() - # keras.backend.get_session().run(init) - # logger.info("Initialized TF global variables.") - - # Load model from TrainingJob data - keras_model = tf.keras.models.load_model(job_model_path(job)) - - # Rename to prevent layer naming conflict - name_prefix = f"{job.model.output_type}_" - keras_model._name = name_prefix + keras_model.name - for i in range(len(keras_model.layers)): - keras_model.layers[i]._name = name_prefix + keras_model.layers[i].name - - return keras_model - -def job_model_path(job: TrainingJob) -> str: - import os - return os.path.join(job.save_dir, job.best_model_filename) - -def get_available_gpus(): - """ - Get the list of available GPUs - - Returns: - List of available GPU device names - """ - - from tensorflow.python.client import device_lib - local_device_protos = device_lib.list_local_devices() - return [x.name for x in local_device_protos if x.device_type == 'GPU'] - -def convert_to_gpu_model(model: keras.Model) -> keras.Model: - gpu_list = get_available_gpus() - - if len(gpu_list) == 0: - logger.warn('No GPU devices, this is going to be really slow, something is wrong, dont do this!!!') - else: - logger.info(f'Detected {len(gpu_list)} GPU(s) for inference') - - if len(gpu_list) > 1: - model = keras.util.multi_gpu_model(model, gpus=len(gpu_list)) - - return model \ No newline at end of file diff --git a/sleap/nn/model.py b/sleap/nn/model.py index 2afa979e9..f89826934 100644 --- a/sleap/nn/model.py +++ b/sleap/nn/model.py @@ -30,6 +30,7 @@ class ModelOutputType(Enum): by Cao et al. """ + CONFIDENCE_MAP = 0 PART_AFFINITY_FIELD = 1 CENTROIDS = 2 @@ -43,7 +44,9 @@ def __str__(self): return "centroids" else: # This shouldn't ever happen I don't think. - raise NotImplementedError(f"__str__ not implemented for ModelOutputType={self}") + raise NotImplementedError( + f"__str__ not implemented for ModelOutputType={self}" + ) @attr.s(auto_attribs=True) @@ -66,25 +69,30 @@ class Model: not set this value. """ + output_type: ModelOutputType - backbone: Union[LeapCNN, UNet, StackedUNet, StackedHourglass] + backbone: BackboneType skeletons: Union[None, List[Skeleton]] = None backbone_name: str = None def __attrs_post_init__(self): if not isinstance(self.backbone, tuple(available_archs)): - raise ValueError(f"backbone ({self.backbone}) is not " - f"in available architectures ({available_archs})") + raise ValueError( + f"backbone ({self.backbone}) is not " + f"in available architectures ({available_archs})" + ) - if not hasattr(self.backbone, 'output'): - raise ValueError(f"backbone ({self.backbone}) has now output method! " - f"Not a valid backbone architecture!") + if not hasattr(self.backbone, "output"): + raise ValueError( + f"backbone ({self.backbone}) has now output method! " + f"Not a valid backbone architecture!" + ) if self.backbone_name is None: self.backbone_name = self.backbone.__class__.__name__ - def output(self, input_tesnor, num_output_channels=None): + def output(self, input_tensor, num_output_channels=None): """ Invoke the backbone function with current backbone_args and backbone_kwargs to produce the model backbone block. This is a convenience property for @@ -108,11 +116,12 @@ def output(self, input_tesnor, num_output_channels=None): elif self.output_type == ModelOutputType.PART_AFFINITY_FIELD: num_outputs_channels = len(self.skeleton[0].edges) * 2 else: - raise ValueError("Model.skeletons has not been set. " - "Cannot infer num output channels.") + raise ValueError( + "Model.skeletons has not been set. " + "Cannot infer num output channels." + ) - - return self.backbone.output(input_tesnor, num_output_channels) + return self.backbone.output(input_tensor, num_output_channels) @property def name(self): @@ -125,3 +134,69 @@ def name(self): """ return self.backbone_name + @property + def down_blocks(self): + """Returns the number of pooling or striding blocks in the backbone. + + This is useful when computing valid dimensions of the input data. + + If the backbone does not provide enough information to infer this, + this is set to 0. + """ + + if hasattr(self.backbone, "down_blocks"): + return self.backbone.down_blocks + + else: + return 0 + + @property + def output_scale(self): + """Calculates output scale relative to input.""" + + if hasattr(self.backbone, "output_scale"): + return self.backbone.output_scale + + elif hasattr(self.backbone, "down_blocks") and hasattr( + self.backbone, "up_blocks" + ): + asym = self.backbone.down_blocks - self.backbone.up_blocks + return 1 / (2 ** asym) + + elif hasattr(self.backbone, "initial_stride"): + return 1 / self.backbone.initial_stride + + else: + return 1 + + @staticmethod + def _structure_model(model_dict, cls): + """Structuring hook for instantiating Model via cattrs. + + This function should be used directly with cattrs as a + structuring hook. It serves the purpose of instantiating + the appropriate backbone class from the string name. + + This is required when backbone classes do not have a + unique attribute name from which to infer the appropriate + class to use. + + Args: + model_dict: Dictionaries containing deserialized Model. + cls: Class to return (not used). + + Returns: + An instantiated Model class with the correct backbone. + + Example: + >> cattr.register_structure_hook(Model, Model.structure_model) + """ + + arch_idx = available_arch_names.index(model_dict["backbone_name"]) + backbone_cls = available_archs[arch_idx] + + return Model( + backbone=backbone_cls(**model_dict["backbone"]), + output_type=ModelOutputType(model_dict["output_type"]), + skeletons=model_dict["skeletons"], + ) diff --git a/sleap/nn/monitor.py b/sleap/nn/monitor.py index 92dcf0f1b..862acc0b5 100644 --- a/sleap/nn/monitor.py +++ b/sleap/nn/monitor.py @@ -4,10 +4,12 @@ import zmq import jsonpickle import logging + logger = logging.getLogger(__name__) from PySide2 import QtCore, QtWidgets, QtGui, QtCharts + class LossViewer(QtWidgets.QMainWindow): def __init__(self, zmq_context=None, show_controller=True, parent=None): super(LossViewer, self).__init__(parent) @@ -15,22 +17,36 @@ def __init__(self, zmq_context=None, show_controller=True, parent=None): self.show_controller = show_controller self.stop_button = None + self.redraw_batch_interval = 40 + self.batches_to_show = 200 # -1 to show all + self.ignore_outliers = False + self.log_scale = True + self.reset() self.setup_zmq(zmq_context) def __del__(self): + self.unbind() + + def close(self): + self.unbind() + super(LossViewer, self).close() + + def unbind(self): # close the zmq socket - self.sub.unbind(self.sub.LAST_ENDPOINT) - self.sub.close() - self.sub = None + if self.sub is not None: + self.sub.unbind(self.sub.LAST_ENDPOINT) + self.sub.close() + self.sub = None if self.zmq_ctrl is not None: url = self.zmq_ctrl.LAST_ENDPOINT self.zmq_ctrl.unbind(url) self.zmq_ctrl.close() self.zmq_ctrl = None # if we started out own zmq context, terminate it - if not self.ctx_given: + if not self.ctx_given and self.ctx is not None: self.ctx.term() + self.ctx = None def reset(self, what=""): self.chart = QtCharts.QtCharts.QChart() @@ -52,24 +68,35 @@ def reset(self, what=""): for s in self.series: self.series[s].pen().setColor(self.color[s]) - self.series["batch"].setMarkerSize(8.) + self.series["batch"].setMarkerSize(8.0) self.chart.addSeries(self.series["batch"]) self.chart.addSeries(self.series["epoch_loss"]) self.chart.addSeries(self.series["val_loss"]) - # self.chart.createDefaultAxes() axisX = QtCharts.QtCharts.QValueAxis() axisX.setLabelFormat("%d") axisX.setTitleText("Batches") self.chart.addAxis(axisX, QtCore.Qt.AlignBottom) - axisY = QtCharts.QtCharts.QLogValueAxis() - axisY.setLabelFormat("%f") - axisY.setLabelsVisible(True) - axisY.setMinorTickCount(1) - axisY.setTitleText("Loss") - axisY.setBase(10) + # create the different Y axes that can be used + self.axisY = dict() + + self.axisY["log"] = QtCharts.QtCharts.QLogValueAxis() + self.axisY["log"].setBase(10) + + self.axisY["linear"] = QtCharts.QtCharts.QValueAxis() + + # settings that apply to all Y axes + for axisY in self.axisY.values(): + axisY.setLabelFormat("%f") + axisY.setLabelsVisible(True) + axisY.setMinorTickCount(1) + axisY.setTitleText("Loss") + + # use the default Y axis + axisY = self.axisY["log"] if self.log_scale else self.axisY["linear"] + self.chart.addAxis(axisY, QtCore.Qt.AlignLeft) for series in self.chart.series(): @@ -86,17 +113,45 @@ def reset(self, what=""): layout.addWidget(self.chartView) if self.show_controller: + control_layout = QtWidgets.QHBoxLayout() + + field = QtWidgets.QCheckBox("Log Scale") + field.setChecked(self.log_scale) + field.stateChanged.connect(lambda x: self.toggle("log_scale")) + control_layout.addWidget(field) + + field = QtWidgets.QCheckBox("Ignore Outliers") + field.setChecked(self.ignore_outliers) + field.stateChanged.connect(lambda x: self.toggle("ignore_outliers")) + control_layout.addWidget(field) + + control_layout.addWidget(QtWidgets.QLabel("Batches to Show:")) + + field = QtWidgets.QComboBox() + self.batch_options = "200,1000,5000,All".split(",") + for opt in self.batch_options: + field.addItem(opt) + field.currentIndexChanged.connect( + lambda x: self.set_batches_to_show(self.batch_options[x]) + ) + control_layout.addWidget(field) + + control_layout.addStretch(1) + self.stop_button = QtWidgets.QPushButton("Stop Training") self.stop_button.clicked.connect(self.stop) - layout.addWidget(self.stop_button) + control_layout.addWidget(self.stop_button) + + widget = QtWidgets.QWidget() + widget.setLayout(control_layout) + layout.addWidget(widget) wid = QtWidgets.QWidget() wid.setLayout(layout) self.setCentralWidget(wid) - # Only show that last 2000 batch values - self.X = deque(maxlen=2000) - self.Y = deque(maxlen=2000) + self.X = [] + self.Y = [] self.t0 = None self.current_job_output_type = what @@ -106,9 +161,43 @@ def reset(self, what=""): self.last_batch_number = 0 self.is_running = False + def toggle(self, what): + if what == "log_scale": + self.log_scale = not self.log_scale + self.update_y_axis() + elif what == "ignore_outliers": + self.ignore_outliers = not self.ignore_outliers + elif what == "entire_history": + if self.batches_to_show > 0: + self.batches_to_show = -1 + else: + self.batches_to_show = 200 + + def set_batches_to_show(self, val): + if val.isdigit(): + self.batches_to_show = int(val) + else: + self.batches_to_show = -1 + + def update_y_axis(self): + to = "log" if self.log_scale else "linear" + # remove other axes + for name, axisY in self.axisY.items(): + if name != to: + if axisY in self.chart.axes(): + self.chart.removeAxis(axisY) + for series in self.chart.series(): + if axisY in series.attachedAxes(): + series.detachAxis(axisY) + # add axis + axisY = self.axisY[to] + self.chart.addAxis(axisY, QtCore.Qt.AlignLeft) + for series in self.chart.series(): + series.attachAxis(axisY) + def setup_zmq(self, zmq_context): # Progress monitoring - self.ctx_given = (zmq_context is not None) + self.ctx_given = zmq_context is not None self.ctx = zmq.Context() if zmq_context is None else zmq_context self.sub = self.ctx.socket(zmq.SUB) self.sub.subscribe("") @@ -127,12 +216,11 @@ def stop(self): if self.zmq_ctrl is not None: # send command to stop training logger.info("Sending command to stop training") - self.zmq_ctrl.send_string(jsonpickle.encode(dict(command="stop",))) + self.zmq_ctrl.send_string(jsonpickle.encode(dict(command="stop"))) if self.stop_button is not None: self.stop_button.setText("Stopping...") self.stop_button.setEnabled(False) - def add_datapoint(self, x, y, which="batch"): # Keep track of all batch points @@ -140,21 +228,45 @@ def add_datapoint(self, x, y, which="batch"): self.X.append(x) self.Y.append(y) - # Redraw batch ever 40 points (faster than plotting each) - if x % 40 == 0: - xs, ys = self.X, self.Y - points = [QtCore.QPointF(x, y) for x, y in zip(xs, ys)] + # Redraw batch at intervals (faster than plotting each) + if x % self.redraw_batch_interval == 0: + + if self.batches_to_show < 0 or len(self.X) < self.batches_to_show: + xs, ys = self.X, self.Y + else: + xs, ys = ( + self.X[-self.batches_to_show :], + self.Y[-self.batches_to_show :], + ) + + points = [QtCore.QPointF(x, y) for x, y in zip(xs, ys) if y > 0] self.series["batch"].replace(points) # Set X scale to show all points dx = 0.5 - self.chart.axisX().setRange(min(self.X) - dx, max(self.X) + dx) + self.chart.axisX().setRange(min(xs) - dx, max(xs) + dx) + + if self.ignore_outliers: + dy = np.ptp(ys) * 0.02 + # Set Y scale to exclude outliers + q1, q3 = np.quantile(ys, (0.25, 0.75)) + iqr = q3 - q1 # interquartile range + low = q1 - iqr * 1.5 + high = q3 + iqr * 1.5 - # Set Y scale to exclude outliers - dy = np.ptp(self.Y) * 0.04 - low, high = np.quantile(self.Y, (.02, .98)) + low = max(low, min(ys) - dy) # keep within range of data + high = min(high, max(ys) + dy) + else: + # Set Y scale to show all points + dy = np.ptp(ys) * 0.02 + low = min(ys) - dy + high = max(ys) + dy + + if self.log_scale: + low = max(low, 1e-5) # for log scale, low cannot be 0 + + self.chart.axisY().setRange(low, high) - self.chart.axisY().setRange(low - dy, high + dy) else: self.series[which].append(x, y) @@ -193,16 +305,28 @@ def check_messages(self, timeout=10): self.epoch = msg["epoch"] elif msg["event"] == "epoch_end": self.epoch_size = max(self.epoch_size, self.last_batch_number + 1) - self.add_datapoint((self.epoch+1)*self.epoch_size, msg["logs"]["loss"], "epoch_loss") + self.add_datapoint( + (self.epoch + 1) * self.epoch_size, + msg["logs"]["loss"], + "epoch_loss", + ) if "val_loss" in msg["logs"].keys(): self.last_epoch_val_loss = msg["logs"]["val_loss"] - self.add_datapoint((self.epoch+1)*self.epoch_size, msg["logs"]["val_loss"], "val_loss") + self.add_datapoint( + (self.epoch + 1) * self.epoch_size, + msg["logs"]["val_loss"], + "val_loss", + ) elif msg["event"] == "batch_end": self.last_batch_number = msg["logs"]["batch"] - self.add_datapoint((self.epoch * self.epoch_size) + msg["logs"]["batch"], msg["logs"]["loss"]) + self.add_datapoint( + (self.epoch * self.epoch_size) + msg["logs"]["batch"], + msg["logs"]["loss"], + ) self.update_runtime() + if __name__ == "__main__": app = QtWidgets.QApplication([]) win = LossViewer() @@ -210,11 +334,11 @@ def check_messages(self, timeout=10): def test_point(x=[0]): x[0] += 1 - i = x[0]+1 - win.add_datapoint(i, i%30+1) + i = x[0] + 1 + win.add_datapoint(i, i % 30 + 1) t = QtCore.QTimer() t.timeout.connect(test_point) t.start(0) - app.exec_() \ No newline at end of file + app.exec_() diff --git a/sleap/nn/peakfinding.py b/sleap/nn/peakfinding.py index d9f600d7f..603910cec 100644 --- a/sleap/nn/peakfinding.py +++ b/sleap/nn/peakfinding.py @@ -1,6 +1,7 @@ import cv2 import numpy as np + def impeaksnms_cv(I, min_thresh=0.3, sigma=3, return_val=True): """ Find peaks via non-maximum suppresion using OpenCV. """ @@ -10,12 +11,10 @@ def impeaksnms_cv(I, min_thresh=0.3, sigma=3, return_val=True): # Blur if sigma is not None: - I = cv2.GaussianBlur(I, (9,9), sigma) + I = cv2.GaussianBlur(I, (9, 9), sigma) # Maximum filter - kernel = np.array([[1,1,1], - [1,0,1], - [1,1,1]]).astype("uint8") + kernel = np.array([[1, 1, 1], [1, 0, 1], [1, 1, 1]]).astype("uint8") m = cv2.dilate(I, kernel) # Convert to points @@ -24,7 +23,7 @@ def impeaksnms_cv(I, min_thresh=0.3, sigma=3, return_val=True): # Return if return_val: - vals = np.array([I[pt[1],pt[0]] for pt in pts]) + vals = np.array([I[pt[1], pt[0]] for pt in pts]) return pts.astype("float32"), vals else: return pts.astype("float32") @@ -38,7 +37,9 @@ def find_all_peaks(confmaps, min_thresh=0.3, sigma=3): peaks_i = [] peak_vals_i = [] for i in range(confmap.shape[-1]): - peak, val = impeaksnms_cv(confmap[...,i], min_thresh=min_thresh, sigma=sigma, return_val=True) + peak, val = impeaksnms_cv( + confmap[..., i], min_thresh=min_thresh, sigma=sigma, return_val=True + ) peaks_i.append(peak) peak_vals_i.append(val) peaks.append(peaks_i) @@ -46,6 +47,7 @@ def find_all_peaks(confmaps, min_thresh=0.3, sigma=3): return peaks, peak_vals + def find_all_single_peaks(confmaps, min_thresh=0.3): """ Finds single peak for each frame/channel in a stack of conf maps. @@ -57,13 +59,17 @@ def find_all_single_peaks(confmaps, min_thresh=0.3): all_point_arrays = [] for confmap in confmaps: - peaks_vals = [image_single_peak(confmap[...,i], min_thresh) for i in range(confmap.shape[-1])] + peaks_vals = [ + image_single_peak(confmap[..., i], min_thresh) + for i in range(confmap.shape[-1]) + ] peaks_vals = [(*point, val) for point, val in peaks_vals] points_array = np.stack(peaks_vals, axis=0) all_point_arrays.append(points_array) return all_point_arrays + def image_single_peak(I, min_thresh): peak = np.unravel_index(I.argmax(), I.shape) val = I[peak] @@ -74,4 +80,4 @@ def image_single_peak(I, min_thresh): else: y, x = peak - return (x, y), val \ No newline at end of file + return (x, y), val diff --git a/sleap/nn/peakfinding_tf.py b/sleap/nn/peakfinding_tf.py index 7d7aa6d14..5c7a1d93e 100644 --- a/sleap/nn/peakfinding_tf.py +++ b/sleap/nn/peakfinding_tf.py @@ -2,15 +2,16 @@ import time import h5py +import keras import tensorflow as tf -keras = tf.keras import numpy as np from typing import Generator, Tuple from sleap.nn.util import batch + def find_maxima_tf(x): col_max = tf.reduce_max(x, axis=1) @@ -24,7 +25,8 @@ def find_maxima_tf(x): maxima = tf.concat([rows, cols], -1) # max_val = tf.reduce_max(col_max, axis=1) # should match tf.reduce_max(x, axis=[1,2]) - return maxima #, max_val + return maxima # , max_val + def impeaksnms_tf(I, min_thresh=0.3): @@ -32,9 +34,7 @@ def impeaksnms_tf(I, min_thresh=0.3): # less than min_thresh are set to 0. It = tf.cast(I > min_thresh, I.dtype) * I - kernel = np.array([[0, 0, 0], - [0, -1, 0], - [0, 0, 0]])[..., None] + kernel = np.array([[0, 0, 0], [0, -1, 0], [0, 0, 0]])[..., None] # kernel = np.array([[1, 1, 1], # [1, 0, 1], # [1, 1, 1]])[..., None] @@ -46,90 +46,117 @@ def impeaksnms_tf(I, min_thresh=0.3): return inds, peak_vals -def find_peaks_tf(confmaps, min_thresh=0.3, upsample_factor: int = 1): - n, h, w, c = confmaps.get_shape().as_list() +def find_peaks_tf( + confmaps, + confmaps_shape, + min_thresh=0.3, + upsample_factor: int = 1, + win_size: int = 5, +): + # n, h, w, c = confmaps.get_shape().as_list() + + h, w, c = confmaps_shape - unrolled_confmaps = tf.reshape(tf.transpose(confmaps, perm=[0, 3, 1, 2]), [-1, h, w, 1]) # nc, h, w, 1 + unrolled_confmaps = tf.reshape( + tf.transpose(confmaps, perm=[0, 3, 1, 2]), [-1, h, w, 1] + ) # (nc, h, w, 1) peak_inds, peak_vals = impeaksnms_tf(unrolled_confmaps, min_thresh=min_thresh) - channel_sample, y, x, _ = tf.split(peak_inds, 4, axis=1) + channel_sample_ind, y, x, _ = tf.split(peak_inds, 4, axis=1) - channel = tf.floormod(channel_sample, c) - sample = tf.floordiv(channel_sample, c) + channel_ind = tf.floormod(channel_sample_ind, c) + sample_ind = tf.floordiv(channel_sample_ind, c) - peaks = tf.concat([sample, y, x, channel], axis=1) + peaks = tf.concat([sample_ind, y, x, channel_ind], axis=1) # (nc, 4) # If we have run prediction on low res and need to upsample the peaks # to a higher resolution. Compute sub-pixel accurate peaks # from these approximate peaks and return the upsampled sub-pixel peaks. if upsample_factor > 1: - win_size = 5 # Must be odd offset = (win_size - 1) / 2 # Get the boxes coordinates centered on the peaks, normalized to image # coordinates - box_ind = tf.squeeze(tf.cast(channel_sample, tf.int32)) - top_left = (tf.to_float(peaks[:, 1:3]) + - tf.constant([-offset, -offset], dtype='float32')) / (h - 1.0) - bottom_right = (tf.to_float(peaks[:, 1:3]) + tf.constant([offset, offset], dtype='float32')) / (w - 1.0) + box_ind = tf.squeeze(tf.cast(channel_sample_ind, tf.int32)) + top_left = ( + tf.cast(peaks[:, 1:3], tf.float32) + + tf.constant([-offset, -offset], dtype="float32") + ) / (h - 1.0) + bottom_right = ( + tf.cast(peaks[:, 1:3], tf.float32) + + tf.constant([offset, offset], dtype="float32") + ) / (w - 1.0) boxes = tf.concat([top_left, bottom_right], axis=1) small_windows = tf.image.crop_and_resize( - unrolled_confmaps, - boxes, - box_ind, - crop_size=[win_size, win_size]) + unrolled_confmaps, boxes, box_ind, crop_size=[win_size, win_size] + ) + # Upsample cropped windows windows = tf.image.resize_bicubic( - small_windows, - [upsample_factor*win_size, upsample_factor*win_size]) + small_windows, [upsample_factor * win_size, upsample_factor * win_size] + ) windows = tf.squeeze(windows) - windows_peaks = find_maxima_tf(windows) - windows_peaks = windows_peaks / win_size - else: - windows_peaks = None - return peaks, peak_vals, windows_peaks + # Find global maximum of each window + windows_peaks = find_maxima_tf(windows) # [row_ind, col_ind] ==> (nc, 2) + + # Adjust back to resolution before upsampling + windows_peaks = tf.cast(windows_peaks, tf.float32) / tf.cast( + upsample_factor, tf.float32 + ) + + # Convert to offsets relative to the original peaks (center of cropped windows) + windows_offsets = windows_peaks - tf.cast(offset, tf.float32) # (nc, 2) + windows_offsets = tf.pad( + windows_offsets, [[0, 0], [1, 1]], mode="CONSTANT", constant_values=0 + ) # (nc, 4) + + # Apply offsets + peaks = tf.cast(peaks, tf.float32) + windows_offsets + + return peaks, peak_vals + # Blurring: # Ref: https://stackoverflow.com/questions/52012657/how-to-make-a-2d-gaussian-filter-in-tensorflow -def gaussian_kernel(size: int, - mean: float, - std: float, - ): +def gaussian_kernel(size: int, mean: float, std: float): """Makes 2D gaussian Kernel for convolution.""" d = tf.distributions.Normal(mean, std) - vals = d.prob(tf.range(start = -size, limit = size + 1, dtype = tf.float32)) - gauss_kernel = tf.einsum('i,j->ij', - vals, - vals) + vals = d.prob(tf.range(start=-size, limit=size + 1, dtype=tf.float32)) + gauss_kernel = tf.einsum("i,j->ij", vals, vals) return gauss_kernel / tf.reduce_sum(gauss_kernel) -# Now we can do peak finding on the GPU like this: -def peak_tf_inference(model, data, - min_thresh: float = 0.3, - gaussian_size: int = 9, - gaussian_sigma: float = 3.0, - upsample_factor: int = 1, - downsample_factor: int = 1, - return_confmaps: bool = False): + +def peak_tf_inference( + model, + data, + confmaps_shape: Tuple[int], + min_thresh: float = 0.3, + gaussian_size: int = 9, + gaussian_sigma: float = 3.0, + upsample_factor: int = 1, + return_confmaps: bool = False, + batch_size: int = 4, + win_size: int = 7, +): sess = keras.backend.get_session() + # TODO: Unfuck this. confmaps = model.outputs[-1] + h, w, c = confmaps_shape - n, h, w, c = confmaps.get_shape().as_list() - - if gaussian_size: + if gaussian_size > 0 and gaussian_sigma > 0: # Make Gaussian Kernel with desired specs. gauss_kernel = gaussian_kernel(size=gaussian_size, mean=0.0, std=gaussian_sigma) - # Expand dimensions of `gauss_kernel` for `tf.nn.seprable_conv2d` signature. + # Expand dimensions of `gauss_kernel` for `tf.nn.separable_conv2d` signature. gauss_kernel = tf.tile(gauss_kernel[:, :, tf.newaxis, tf.newaxis], [1, 1, c, 1]) # Create a pointwise filter that does nothing, we are using separable convultions to blur @@ -137,91 +164,96 @@ def peak_tf_inference(model, data, pointwise_filter = tf.eye(c, batch_shape=[1, 1]) # Convolve. - blurred_confmaps = tf.nn.separable_conv2d(confmaps, gauss_kernel, pointwise_filter, - strides=[1, 1, 1, 1], padding='SAME') - - inds, peak_vals, windows = find_peaks_tf(blurred_confmaps, min_thresh=min_thresh, - upsample_factor=upsample_factor) - else: - inds, peak_vals, windows = find_peaks_tf(confmaps, min_thresh=min_thresh, - upsample_factor=upsample_factor) + confmaps = tf.nn.separable_conv2d( + confmaps, + gauss_kernel, + pointwise_filter, + strides=[1, 1, 1, 1], + padding="SAME", + ) + + # Setup peak finding computations. + peaks, peak_vals = find_peaks_tf( + confmaps, + confmaps_shape=confmaps_shape, + min_thresh=min_thresh, + upsample_factor=upsample_factor, + win_size=win_size, + ) # We definitely want to capture the peaks in the output # We will map the tensorflow outputs onto a dict to return - outputs_dict = dict(peaks=inds, peak_vals=peak_vals) - - if upsample_factor > 1: - outputs_dict["windows"] = windows + outputs_dict = dict(peaks=peaks, peak_vals=peak_vals) if return_confmaps: outputs_dict["confmaps"] = confmaps # Convert dict to list of keys and list of tensors (to evaluate) - outputs_keys, outputs_vals = list(outputs_dict.keys()), list(outputs_dict.values()) + outputs_keys, output_tensors = ( + list(outputs_dict.keys()), + list(outputs_dict.values()), + ) - peaks = [] - peak_vals = [] - windows = [] - confmaps = [] - - for batch_number, row_offset, data_batch in batch(data, batch_size=2): + # Run the graph and retrieve output arrays. + peaks_arr = [] + peak_vals_arr = [] + confmaps_arr = [] + for batch_number, row_offset, data_batch in batch(data, batch_size=batch_size): # This does the actual evaluation - outputs = sess.run(outputs_vals, feed_dict={ model.input: data_batch }) + outputs_arr = sess.run(output_tensors, feed_dict={model.input: data_batch}) # Convert list of results to dict using saved list of keys - outputs_dict = dict(zip(outputs_keys, outputs)) + outputs_arr_dict = dict(zip(outputs_keys, outputs_arr)) - batch_peaks = outputs_dict["peaks"] + batch_peaks = outputs_arr_dict["peaks"] # First column should match row number in full data matrix, # so we add row offset of batch to row number in batch matrix. - batch_peaks[:,0] += row_offset - - peaks.append(batch_peaks) - peak_vals.append(outputs_dict["peak_vals"]) + batch_peaks[:, 0] += row_offset - if "windows" in outputs_dict: - windows.append(outputs_dict["windows"]) + peaks_arr.append(batch_peaks) + peak_vals_arr.append(outputs_arr_dict["peak_vals"]) if "confmaps" in outputs_dict: - confmaps.append(outputs_dict["confmaps"]) + confmaps.append(outputs_arr_dict["confmaps"]) - peaks = np.concatenate(peaks) - peak_vals = np.concatenate(peak_vals) - confmaps = np.concatenate(confmaps) if len(confmaps) else None + peaks_arr = np.concatenate(peaks_arr, axis=0) + peak_vals_arr = np.concatenate(peak_vals_arr, axis=0) + confmaps_arr = np.concatenate(confmaps_arr, axis=0) if len(confmaps_arr) else None # Extract frame and node index columns - frame_node_idx = peaks[:, [0, 3]] + sample_channel_ind = peaks_arr[:, [0, 3]] # (nc, 2) # Extract X and Y columns - peak_points = peaks[:,[1,2]].astype("float") - - # Add offset from upsampling window peak if upsampling - if upsample_factor > 1 and len(windows): - windows = np.concatenate(windows) - peak_points += windows/upsample_factor - - if downsample_factor > 1: - peak_points /= downsample_factor - - # Swap the X and Y columns (order was [row idx, col idx]) - peak_points = peak_points[:,[1,0]] + peak_points = peaks_arr[:, [2, 1]].astype("float") # [x, y] ==> (nc, 2) # Use indices to convert matrices to lists of lists # (this matches the format of cpu-based peak-finding) - peak_list, peak_val_list = split_matrices_by_double_index(frame_node_idx, peak_points, peak_vals) + peak_list, peak_val_list = split_matrices_by_double_index( + sample_channel_ind, + peak_points, + peak_vals_arr, + n_samples=len(data), + n_channels=c, + ) return peak_list, peak_val_list, confmaps -def split_matrices_by_double_index(idxs, *data_list): + +def split_matrices_by_double_index(idxs, *data_list, n_samples=None, n_channels=None): """Convert data matrices to lists of lists expected by other functions.""" # Return empty array if there are no idxs - if len(idxs) == 0: return [], [] + if len(idxs) == 0: + return [], [] # Determine the list length for major and minor indices - max_idx_vals = np.max(idxs, axis=0).astype("int") + 1 + if n_samples is None: + n_samples = np.max(idxs[:, 0]) + 1 + + if n_channels is None: + n_channels = np.max(idxs[:, 1]) + 1 # We can accept a variable number of data matrices data_matrix_count = len(data_list) @@ -230,18 +262,18 @@ def split_matrices_by_double_index(idxs, *data_list): r = [[] for _ in range(data_matrix_count)] # Loop over major index (frame) - for i in range(max_idx_vals[0]): + for t in range(n_samples): # Empty list for this value of major index # for results from each data matrix major = [[] for _ in range(data_matrix_count)] # Loop over minor index (node) - for j in range(max_idx_vals[1]): + for c in range(n_channels): # Use idxs matrix to determine which rows # to retrieve from each data matrix - mask = np.all((idxs == [i,j]), axis = 1) + mask = np.all((idxs == [t, c]), axis=1) # Get rows from each data matrix for data_matrix_idx, matrix in enumerate(data_list): @@ -252,4 +284,4 @@ def split_matrices_by_double_index(idxs, *data_list): for data_matrix_idx in range(data_matrix_count): r[data_matrix_idx].append(major[data_matrix_idx]) - return r \ No newline at end of file + return r diff --git a/sleap/nn/peakmatching.py b/sleap/nn/peakmatching.py index 5023e81fd..228bd8ee1 100644 --- a/sleap/nn/peakmatching.py +++ b/sleap/nn/peakmatching.py @@ -4,6 +4,7 @@ from sleap.instance import LabeledFrame, PredictedPoint, PredictedInstance from sleap.info.metrics import calculate_pairwise_cost + def match_single_peaks_frame(points_array, skeleton, transform, img_idx): """ Make instance from points array returned by single peak finding. @@ -13,10 +14,11 @@ def match_single_peaks_frame(points_array, skeleton, transform, img_idx): Returns: PredictedInstance, or None if no points. """ - if points_array.shape[0] == 0: return None + if points_array.shape[0] == 0: + return None # apply inverse transform to points - points_array[...,0:2] = transform.invert(img_idx, points_array[...,0:2]) + points_array[..., 0:2] = transform.invert(img_idx, points_array[..., 0:2]) pts = dict() for i, node in enumerate(skeleton.nodes): @@ -29,11 +31,14 @@ def match_single_peaks_frame(points_array, skeleton, transform, img_idx): matched_instance = None if len(pts) > 0: # FIXME: how should we calculate score for instance? - inst_score = np.sum(points_array[...,2]) / len(pts) - matched_instance = PredictedInstance(skeleton=skeleton, points=pts, score=inst_score) + inst_score = np.sum(points_array[..., 2]) / len(pts) + matched_instance = PredictedInstance( + skeleton=skeleton, points=pts, score=inst_score + ) return matched_instance + def match_single_peaks_all(points_arrays, skeleton, video, transform): """ Make labeled frames for the results of single peak finding. @@ -52,6 +57,7 @@ def match_single_peaks_all(points_arrays, skeleton, video, transform): predicted_frames.append(new_lf) return predicted_frames + def improfile(I, p0, p1, max_points=None): """ Returns values of the image I evaluated along the line formed @@ -73,7 +79,7 @@ def improfile(I, p0, p1, max_points=None): I = np.squeeze(I) # Find number of points to extract - n = np.sqrt((p0[0] - p1[0])**2 + (p0[1] - p1[1])**2) + n = np.sqrt((p0[0] - p1[0]) ** 2 + (p0[1] - p1[1]) ** 2) n = max(n, 1) if max_points is not None: n = min(n, max_points) @@ -84,15 +90,23 @@ def improfile(I, p0, p1, max_points=None): y = np.round(np.linspace(p0[1], p1[1], n)).astype("int32") # Extract values and concatenate into vector - vals = np.stack([I[yi,xi] for xi, yi in zip(x,y)]) + vals = np.stack([I[yi, xi] for xi, yi in zip(x, y)]) return vals -def match_peaks_frame(peaks_t, peak_vals_t, pafs_t, skeleton, transform, img_idx, - min_score_to_node_ratio=0.4, - min_score_midpts=0.05, - min_score_integral=0.8, - add_last_edge=False, - single_per_crop=True): + +def match_peaks_frame( + peaks_t, + peak_vals_t, + pafs_t, + skeleton, + transform, + img_idx, + min_score_to_node_ratio=0.4, + min_score_midpts=0.05, + min_score_integral=0.8, + add_last_edge=False, + single_per_crop=True, +): """ Matches single frame """ @@ -116,8 +130,8 @@ def match_peaks_frame(peaks_t, peak_vals_t, pafs_t, skeleton, transform, img_idx for k, edge in enumerate(skeleton.edge_names): src_node_idx = skeleton.node_to_index(edge[0]) dst_node_idx = skeleton.node_to_index(edge[1]) - paf_x = pafs_t[...,2*k] - paf_y = pafs_t[...,2*k+1] + paf_x = pafs_t[..., 2 * k] + paf_y = pafs_t[..., 2 * k + 1] # Make sure matrix has rows for these nodes if len(peaks_t) <= src_node_idx or len(peaks_t) <= dst_node_idx: @@ -156,16 +170,23 @@ def match_peaks_frame(peaks_t, peak_vals_t, pafs_t, skeleton, transform, img_idx # Compute score score_midpts = vec_x * vec[0] + vec_y * vec[1] - score_with_dist_prior = np.mean(score_midpts) + min(0.5 * paf_x.shape[0] / norm - 1, 0) + score_with_dist_prior = np.mean(score_midpts) + min( + 0.5 * paf_x.shape[0] / norm - 1, 0 + ) score_integral = np.mean(score_midpts > min_score_midpts) - if score_with_dist_prior > 0 and score_integral > min_score_integral: + if ( + score_with_dist_prior > 0 + and score_integral > min_score_integral + ): connection_candidates.append([i, j, score_with_dist_prior]) # Sort candidates for current edge by descending score - connection_candidates = sorted(connection_candidates, key=lambda x: x[2], reverse=True) + connection_candidates = sorted( + connection_candidates, key=lambda x: x[2], reverse=True + ) # Add to list of candidates for next step - connection = np.zeros((0,5)) # src_id, dst_id, paf_score, i, j + connection = np.zeros((0, 5)) # src_id, dst_id, paf_score, i, j for candidate in connection_candidates: i, j, score = candidate # Add to connections if node is not already included @@ -180,20 +201,27 @@ def match_peaks_frame(peaks_t, peak_vals_t, pafs_t, skeleton, transform, img_idx connection_all.append(connection) # Greedy matching of each edge candidate set - subset = -1 * np.ones((0, len(skeleton.nodes)+2)) # ids, overall score, number of parts - candidate = np.array([y for x in peaks_t for y in x]) # flattened set of all points - candidate_scores = np.array([y for x in peak_vals_t for y in x]) # flattened set of all peak scores + subset = -1 * np.ones( + (0, len(skeleton.nodes) + 2) + ) # ids, overall score, number of parts + candidate = np.array([y for x in peaks_t for y in x]) # flattened set of all points + candidate_scores = np.array( + [y for x in peak_vals_t for y in x] + ) # flattened set of all peak scores for k, edge in enumerate(skeleton.edge_names): # No matches for this edge if k in special_k: continue # Get IDs for current connection - partAs = connection_all[k][:,0] - partBs = connection_all[k][:,1] + partAs = connection_all[k][:, 0] + partBs = connection_all[k][:, 1] # Get edge - indexA, indexB = (skeleton.node_to_index(edge[0]), skeleton.node_to_index(edge[1])) + indexA, indexB = ( + skeleton.node_to_index(edge[0]), + skeleton.node_to_index(edge[1]), + ) # Loop through all candidates for current edge for i in range(len(connection_all[k])): @@ -209,18 +237,24 @@ def match_peaks_frame(peaks_t, peak_vals_t, pafs_t, skeleton, transform, img_idx # One of the two candidate points found in matched subset if found == 1: j = subset_idx[0] - if subset[j][indexB] != partBs[i]: # did we already assign this part? - subset[j][indexB] = partBs[i] # assign part - subset[j][-1] += 1 # increment instance part counter - subset[j][-2] += candidate_scores[int(partBs[i])] + connection_all[k][i][2] # add peak + edge score + if subset[j][indexB] != partBs[i]: # did we already assign this part? + subset[j][indexB] = partBs[i] # assign part + subset[j][-1] += 1 # increment instance part counter + subset[j][-2] += ( + candidate_scores[int(partBs[i])] + connection_all[k][i][2] + ) # add peak + edge score # Both candidate points found in matched subset elif found == 2: - j1, j2 = subset_idx # get indices in matched subset - membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2] # count number of instances per body parts + j1, j2 = subset_idx # get indices in matched subset + membership = ( + (subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int) + )[ + :-2 + ] # count number of instances per body parts # All body parts are disjoint, merge them if np.all(membership < 2): - subset[j1][:-2] += (subset[j2][:-2] + 1) + subset[j1][:-2] += subset[j2][:-2] + 1 subset[j1][-2:] += subset[j2][-2:] subset[j1][-2] += connection_all[k][i][2] subset = np.delete(subset, j2, axis=0) @@ -229,24 +263,30 @@ def match_peaks_frame(peaks_t, peak_vals_t, pafs_t, skeleton, transform, img_idx else: subset[j1][indexB] = partBs[i] subset[j1][-1] += 1 - subset[j1][-2] += candidate_scores[partBs[i].astype(int)] + connection_all[k][i][2] + subset[j1][-2] += ( + candidate_scores[partBs[i].astype(int)] + + connection_all[k][i][2] + ) # Neither point found, create a new subset (if not the last edge) - elif found == 0 and (add_last_edge or (k < (len(skeleton.edges)-1))): - row = -1 * np.ones(len(skeleton.nodes)+2) - row[indexA] = partAs[i] # ID - row[indexB] = partBs[i] # ID - row[-1] = 2 # initial count - row[-2] = sum(candidate_scores[connection_all[k][i, :2].astype(int)]) + connection_all[k][i][2] # score - subset = np.vstack([subset, row]) # add to matched subset + elif found == 0 and (add_last_edge or (k < (len(skeleton.edges) - 1))): + row = -1 * np.ones(len(skeleton.nodes) + 2) + row[indexA] = partAs[i] # ID + row[indexB] = partBs[i] # ID + row[-1] = 2 # initial count + row[-2] = ( + sum(candidate_scores[connection_all[k][i, :2].astype(int)]) + + connection_all[k][i][2] + ) # score + subset = np.vstack([subset, row]) # add to matched subset # Filter small instances - score_to_node_ratio = subset[:,-2] / subset[:,-1] + score_to_node_ratio = subset[:, -2] / subset[:, -1] subset = subset[score_to_node_ratio > min_score_to_node_ratio, :] - # apply inverse transform to points + # Apply inverse transform to points to return to full resolution, uncropped image coordinates if candidate.shape[0] > 0: - candidate[...,0:2] = transform.invert(img_idx, candidate[...,0:2]) + candidate[..., 0:2] = transform.invert(img_idx, candidate[..., 0:2]) # Done with all the matching! Gather the data matched_instances_t = [] @@ -257,79 +297,123 @@ def match_peaks_frame(peaks_t, peak_vals_t, pafs_t, skeleton, transform, img_idx for i, node_name in enumerate(skeleton.node_names): if match[i] >= 0: match_idx = int(match[i]) - pt = PredictedPoint(x=candidate[match_idx, 0], y=candidate[match_idx, 1], - score=candidate_scores[match_idx]) + pt = PredictedPoint( + x=candidate[match_idx, 0], + y=candidate[match_idx, 1], + score=candidate_scores[match_idx], + ) pts[node_name] = pt if len(pts): - matched_instances_t.append(PredictedInstance(skeleton=skeleton, - points=pts, - score=match[-2])) + matched_instances_t.append( + PredictedInstance(skeleton=skeleton, points=pts, score=match[-2]) + ) # For centroid crop just return instance closest to centroid - if single_per_crop and len(matched_instances_t) > 1 and transform.is_cropped: + # if single_per_crop and len(matched_instances_t) > 1 and transform.is_cropped: - crop_centroid = np.array(((transform.crop_size//2, transform.crop_size//2),)) # center of crop box - crop_centroid = transform.invert(img_idx, crop_centroid) # relative to original image + # crop_centroid = np.array(((transform.crop_size//2, transform.crop_size//2),)) # center of crop box + # crop_centroid = transform.invert(img_idx, crop_centroid) # relative to original image - # sort by distance from crop centroid - matched_instances_t.sort(key=lambda inst: np.linalg.norm(inst.centroid - crop_centroid)) + # # sort by distance from crop centroid + # matched_instances_t.sort(key=lambda inst: np.linalg.norm(inst.centroid - crop_centroid)) - # logger.debug(f"SINGLE_INSTANCE_PER_CROP: crop has {len(matched_instances_t)} instances, filter to 1.") + # # logger.debug(f"SINGLE_INSTANCE_PER_CROP: crop has {len(matched_instances_t)} instances, filter to 1.") - # just use closest - matched_instances_t = matched_instances_t[0:1] + # # just use closest + # matched_instances_t = matched_instances_t[0:1] + + if single_per_crop and len(matched_instances_t) > 1 and transform.is_cropped: + # Just keep highest scoring instance + matched_instances_t = [matched_instances_t[0]] return matched_instances_t -def match_peaks_paf(peaks, peak_vals, pafs, skeleton, - video, transform, - min_score_to_node_ratio=0.4, min_score_midpts=0.05, - min_score_integral=0.8, add_last_edge=False, single_per_crop=True, - **kwargs): + +def match_peaks_paf( + peaks, + peak_vals, + pafs, + skeleton, + video, + transform, + min_score_to_node_ratio=0.4, + min_score_midpts=0.05, + min_score_integral=0.8, + add_last_edge=False, + single_per_crop=True, + **kwargs +): """ Computes PAF-based peak matching via greedy assignment """ # Process each frame predicted_frames = [] - for img_idx, (peaks_t, peak_vals_t, pafs_t) in enumerate(zip(peaks, peak_vals, pafs)): - instances = match_peaks_frame(peaks_t, peak_vals_t, pafs_t, skeleton, - transform, img_idx, - min_score_to_node_ratio=min_score_to_node_ratio, - min_score_midpts=min_score_midpts, - min_score_integral=min_score_integral, - add_last_edge=add_last_edge, - single_per_crop=single_per_crop) + for img_idx, (peaks_t, peak_vals_t, pafs_t) in enumerate( + zip(peaks, peak_vals, pafs) + ): + instances = match_peaks_frame( + peaks_t, + peak_vals_t, + pafs_t, + skeleton, + transform, + img_idx, + min_score_to_node_ratio=min_score_to_node_ratio, + min_score_midpts=min_score_midpts, + min_score_integral=min_score_integral, + add_last_edge=add_last_edge, + single_per_crop=single_per_crop, + ) frame_idx = transform.get_frame_idxs(img_idx) - predicted_frames.append(LabeledFrame(video=video, frame_idx=frame_idx, instances=instances)) + predicted_frames.append( + LabeledFrame(video=video, frame_idx=frame_idx, instances=instances) + ) # Combine LabeledFrame objects for the same video frame predicted_frames = LabeledFrame.merge_frames(predicted_frames, video=video) return predicted_frames -def match_peaks_paf_par(peaks, peak_vals, pafs, skeleton, - video, transform, - min_score_to_node_ratio=0.4, - min_score_midpts=0.05, - min_score_integral=0.8, - add_last_edge=False, - single_per_crop=True, - pool=None, **kwargs): + +def match_peaks_paf_par( + peaks, + peak_vals, + pafs, + skeleton, + video, + transform, + min_score_to_node_ratio=0.4, + min_score_midpts=0.05, + min_score_integral=0.8, + add_last_edge=False, + single_per_crop=True, + pool=None, + **kwargs +): """ Parallel version of PAF peak matching """ if pool is None: + import multiprocessing + pool = multiprocessing.Pool() futures = [] - for img_idx, (peaks_t, peak_vals_t, pafs_t) in enumerate(zip(peaks, peak_vals, pafs)): - future = pool.apply_async(match_peaks_frame, - [peaks_t, peak_vals_t, pafs_t, skeleton], - dict(transform=transform, img_idx=img_idx, - min_score_to_node_ratio=min_score_to_node_ratio, - min_score_midpts=min_score_midpts, - min_score_integral=min_score_integral, - add_last_edge=add_last_edge, - single_per_crop=single_per_crop,)) + for img_idx, (peaks_t, peak_vals_t, pafs_t) in enumerate( + zip(peaks, peak_vals, pafs) + ): + future = pool.apply_async( + match_peaks_frame, + [peaks_t, peak_vals_t, pafs_t, skeleton], + dict( + transform=transform, + img_idx=img_idx, + min_score_to_node_ratio=min_score_to_node_ratio, + min_score_midpts=min_score_midpts, + min_score_integral=min_score_integral, + add_last_edge=add_last_edge, + single_per_crop=single_per_crop, + ), + ) futures.append(future) predicted_frames = [] @@ -342,32 +426,46 @@ def match_peaks_paf_par(peaks, peak_vals, pafs, skeleton, # an expensive operation. for i in range(len(instances)): points = {node.name: point for node, point in instances[i].nodes_points} - instances[i] = PredictedInstance(skeleton=skeleton, points=points, score=instances[i].score) + instances[i] = PredictedInstance( + skeleton=skeleton, points=points, score=instances[i].score + ) - predicted_frames.append(LabeledFrame(video=video, frame_idx=frame_idx, instances=instances)) + predicted_frames.append( + LabeledFrame(video=video, frame_idx=frame_idx, instances=instances) + ) # Combine LabeledFrame objects for the same video frame predicted_frames = LabeledFrame.merge_frames(predicted_frames, video=video) return predicted_frames -def instances_nms(instances: List[PredictedInstance], thresh: float=4) -> List[PredictedInstance]: + +def instances_nms( + instances: List[PredictedInstance], thresh: float = 4 +) -> List[PredictedInstance]: """Remove overlapping instances from list.""" - if len(instances) <= 1: return + if len(instances) <= 1: + return # Look for overlapping instances - overlap_matrix = calculate_pairwise_cost(instances, instances, - cost_function = lambda x: np.nan if all(np.isnan(x)) else np.nanmean(x)) + overlap_matrix = calculate_pairwise_cost( + instances, + instances, + cost_function=lambda x: np.nan if all(np.isnan(x)) else np.nanmean(x), + ) # Set diagonals over threshold since an instance doesn't overlap with itself - np.fill_diagonal(overlap_matrix, thresh+1) - overlap_matrix[np.isnan(overlap_matrix)] = thresh+1 + np.fill_diagonal(overlap_matrix, thresh + 1) + overlap_matrix[np.isnan(overlap_matrix)] = thresh + 1 instances_to_remove = [] def sort_funct(inst_idx): # sort by number of points in instance, then by prediction score (desc) - return (len(instances[inst_idx].nodes), -getattr(instances[inst_idx], "score", 0)) + return ( + len(instances[inst_idx].nodes), + -getattr(instances[inst_idx], "score", 0), + ) while np.nanmin(overlap_matrix) < thresh: # Find the pair of instances with greatest overlap @@ -379,8 +477,8 @@ def sort_funct(inst_idx): keep_idx = idxs[-1] # Remove this instance from overlap matrix - overlap_matrix[pick_idx, :] = thresh+1 - overlap_matrix[:, pick_idx] = thresh+1 + overlap_matrix[pick_idx, :] = thresh + 1 + overlap_matrix[:, pick_idx] = thresh + 1 # Add to list of instances that we'll remove. # We'll remove these later so list index doesn't change now. @@ -389,4 +487,4 @@ def sort_funct(inst_idx): # Remove selected instances from list # Note that we're modifying the original list in place for inst in instances_to_remove: - instances.remove(inst) \ No newline at end of file + instances.remove(inst) diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 2facf5bb3..413aaa573 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -28,13 +28,14 @@ class ShiftedInstance: Args: parent: The Instance that this optical flow shifted instance is derived from. """ - parent: Union[Instance, 'ShiftedInstance'] = attr.ib() + + parent: Union[Instance, "ShiftedInstance"] = attr.ib() frame: Union[LabeledFrame, None] = attr.ib() points: np.ndarray = attr.ib() @property @functools.lru_cache() - def source(self) -> 'Instance': + def source(self) -> "Instance": """ Recursively discover root instance to a chain of flow shifted instances. @@ -67,7 +68,7 @@ def frame_idx(self) -> int: """ return self.frame.frame_idx - def points_array(self, *args, **kwargs): + def get_points_array(self, *args, **kwargs): """ Return the ShiftedInstance as a numpy array. ShiftedInstance stores its points as an array always, unlike the Instance class. This method provides @@ -79,6 +80,7 @@ def points_array(self, *args, **kwargs): """ return self.points + @attr.s(slots=True) class Tracks: instances: Dict[int, list] = attr.ib(default=attr.Factory(dict)) @@ -91,13 +93,21 @@ def get_frame_instances(self, frame_idx: int, max_shift=None): # Filter if max_shift is not None: - instances = [instance for instance in instances if isinstance(instance, Instance) or ( - isinstance(instance, ShiftedInstance) and ( - (frame_idx - instance.source.frame_idx) <= max_shift))] + instances = [ + instance + for instance in instances + if isinstance(instance, Instance) + or ( + isinstance(instance, ShiftedInstance) + and ((frame_idx - instance.source.frame_idx) <= max_shift) + ) + ] return instances - def add_instance(self, instance: Union[Instance, 'ShiftedInstance'], frame_index: int): + def add_instance( + self, instance: Union[Instance, "ShiftedInstance"], frame_index: int + ): frame_instances = self.instances.get(frame_index, []) frame_instances.append(instance) self.instances[frame_index] = frame_instances @@ -113,11 +123,17 @@ def get_last_known(self, curr_frame_index: int = None, max_shift: int = None): return list(self.last_known_instance.values()) else: if max_shift is None: - return [i for i in self.last_known_instance.values() - if i.track == curr_frame_index] + return [ + i + for i in self.last_known_instance.values() + if i.track == curr_frame_index + ] else: - return [i for i in self.last_known_instance.values() - if (curr_frame_index-i.frame_idx) < max_shift] + return [ + i + for i in self.last_known_instance.values() + if (curr_frame_index - i.frame_idx) < max_shift + ] def update_track_last_known(self, frame: LabeledFrame, max_shift: int = None): for i in frame.instances: @@ -126,9 +142,11 @@ def update_track_last_known(self, frame: LabeledFrame, max_shift: int = None): # Remove tracks from the dict that have exceeded the max_shift horizon if max_shift is not None: - del_tracks = [track - for track, instance in self.last_known_instance.items() - if (frame.frame_idx-instance.frame_idx) > max_shift] + del_tracks = [ + track + for track, instance in self.last_known_instance.items() + if (frame.frame_idx - instance.frame_idx) > max_shift + ] for key in del_tracks: del self.last_known_instance[key] @@ -153,7 +171,7 @@ class FlowShiftTracker: """ window: int = 10 - of_win_size: Tuple = (21,21) + of_win_size: Tuple = (21, 21) of_max_level: int = 3 of_max_count: int = 30 of_epsilon: float = 0.01 @@ -168,7 +186,7 @@ def __attrs_post_init__(self): def _fix_img(self, img: np.ndarray): # Drop single channel dimension and convert to uint8 in [0, 255] range - curr_img = (np.squeeze(img)*255).astype(np.uint8) + curr_img = (np.squeeze(img) * 255).astype(np.uint8) np.clip(curr_img, 0, 255) # If we still have 3 dimensions the image is color, need to convert @@ -178,9 +196,7 @@ def _fix_img(self, img: np.ndarray): return curr_img - def process(self, - imgs: np.ndarray, - labeled_frames: List[LabeledFrame]): + def process(self, imgs: np.ndarray, labeled_frames: List[LabeledFrame]): """ Flow shift track a batch of frames with matched instances for each frame represented as a list of LabeledFrame's. @@ -215,14 +231,16 @@ def process(self, # known instance for each track. Do this for the last frame and # skip on the first frame. if img_idx > 0: - self.tracks.update_track_last_known(labeled_frames[img_idx-1], max_shift=None) + self.tracks.update_track_last_known( + labeled_frames[img_idx - 1], max_shift=None + ) # Copy the actual frame index for this labeled frame, we will # use this a lot. self.last_frame_index = t t = frame.frame_idx - instances_pts = [i.points_array() for i in frame.instances] + instances_pts = [i.get_points_array() for i in frame.instances] # If we do not have any active tracks, we will spawn one for each # matched instance and continue to the next frame. @@ -232,30 +250,45 @@ def process(self, instance.track = Track(spawned_on=t, name=f"{i}") self.tracks.add_instance(instance, frame_index=t) - logger.debug(f"[t = {t}] Created {len(self.tracks.tracks)} initial tracks") + logger.debug( + f"[t = {t}] Created {len(self.tracks.tracks)} initial tracks" + ) self.last_img = self._fix_img(imgs[img_idx].copy()) continue # Get all points in reference frame - instances_ref = self.tracks.get_frame_instances(self.last_frame_index, max_shift=self.window - 1) - pts_ref = [instance.points_array() for instance in instances_ref] - - tmp = min([instance.frame_idx for instance in instances_ref] + - [instance.source.frame_idx for instance in instances_ref - if isinstance(instance, ShiftedInstance)]) + instances_ref = self.tracks.get_frame_instances( + self.last_frame_index, max_shift=self.window - 1 + ) + pts_ref = [instance.get_points_array() for instance in instances_ref] + + tmp = min( + [instance.frame_idx for instance in instances_ref] + + [ + instance.source.frame_idx + for instance in instances_ref + if isinstance(instance, ShiftedInstance) + ] + ) logger.debug(f"[t = {t}] Using {len(instances_ref)} refs back to t = {tmp}") curr_img = self._fix_img(imgs[img_idx].copy()) - pts_fs, status, err = \ - cv2.calcOpticalFlowPyrLK(self.last_img, curr_img, - (np.concatenate(pts_ref, axis=0)).astype("float32"), - None, winSize=self.of_win_size, - maxLevel=self.of_max_level, - criteria=(cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, - self.of_max_count, self.of_epsilon)) + pts_fs, status, err = cv2.calcOpticalFlowPyrLK( + self.last_img, + curr_img, + (np.concatenate(pts_ref, axis=0)).astype("float32"), + None, + winSize=self.of_win_size, + maxLevel=self.of_max_level, + criteria=( + cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, + self.of_max_count, + self.of_epsilon, + ), + ) self.last_img = curr_img # Split by instance @@ -266,22 +299,30 @@ def process(self, err = np.split(err, sections, axis=0) # Store shifted instances with metadata - shifted_instances = [ShiftedInstance(parent=ref, points=pts, frame=frame) - for ref, pts, found in zip(instances_ref, pts_fs, status) - if np.sum(found) > 0] + shifted_instances = [ + ShiftedInstance(parent=ref, points=pts, frame=frame) + for ref, pts, found in zip(instances_ref, pts_fs, status) + if np.sum(found) > 0 + ] # Get the track present in the shifted instances shifted_tracks = list({instance.track for instance in shifted_instances}) - last_known = self.tracks.get_last_known(curr_frame_index=t, max_shift=self.window) + last_known = self.tracks.get_last_known( + curr_frame_index=t, max_shift=self.window + ) alive_tracks = {i.track for i in last_known} # If we didn't get any shifted instances from the reference frame, use the last # know positions for each track. if len(shifted_instances) == 0: - logger.debug(f"[t = {t}] Optical flow failed, using last known positions for each track.") + logger.debug( + f"[t = {t}] Optical flow failed, using last known positions for each track." + ) shifted_instances = self.tracks.get_last_known() - shifted_tracks = list({instance.track for instance in shifted_instances}) + shifted_tracks = list( + {instance.track for instance in shifted_instances} + ) else: # We might have got some shifted instances, but make sure we aren't missing any # tracks @@ -298,21 +339,35 @@ def process(self, continue # Reduce distances by track - unassigned_pts = np.stack(instances_pts, axis=0) # instances x nodes x 2 - logger.debug(f"[t = {t}] Flow shift matching {len(unassigned_pts)} " - f"instances to {len(shifted_tracks)} ref tracks") + unassigned_pts = np.stack(instances_pts, axis=0) # instances x nodes x 2 + logger.debug( + f"[t = {t}] Flow shift matching {len(unassigned_pts)} " + f"instances to {len(shifted_tracks)} ref tracks" + ) cost_matrix = np.full((len(unassigned_pts), len(shifted_tracks)), np.nan) for i, track in enumerate(shifted_tracks): # Get shifted points for current track - track_pts = np.stack([instance.points_array() - for instance in shifted_instances - if instance.track == track], axis=0) # track_instances x nodes x 2 + track_pts = np.stack( + [ + instance.get_points_array() + for instance in shifted_instances + if instance.track == track + ], + axis=0, + ) # track_instances x nodes x 2 # Compute pairwise distances between points - distances = np.sqrt(np.sum((np.expand_dims(unassigned_pts / img_scale, axis=1) - - np.expand_dims(track_pts, axis=0)) ** 2, - axis=-1)) # unassigned_instances x track_instances x nodes + distances = np.sqrt( + np.sum( + ( + np.expand_dims(unassigned_pts / img_scale, axis=1) + - np.expand_dims(track_pts, axis=0) + ) + ** 2, + axis=-1, + ) + ) # unassigned_instances x track_instances x nodes # Reduce over nodes and instances distances = -np.nansum(np.exp(-distances), axis=(1, 2)) @@ -328,18 +383,23 @@ def process(self, frame.instances[i].track = shifted_tracks[j] self.tracks.add_instance(frame.instances[i], frame_index=t) - logger.debug(f"[t = {t}] Assigned instance {i} to existing track " - f"{shifted_tracks[j].name} (cost = {cost_matrix[i,j]})") + logger.debug( + f"[t = {t}] Assigned instance {i} to existing track " + f"{shifted_tracks[j].name} (cost = {cost_matrix[i,j]})" + ) # Spawn new tracks for unassigned instances for i, pts in enumerate(unassigned_pts): - if i in assigned_ind: continue + if i in assigned_ind: + continue instance = frame.instances[i] instance.track = Track(spawned_on=t, name=f"{len(self.tracks.tracks)}") self.tracks.add_instance(instance, frame_index=t) - logger.debug(f"[t = {t}] Assigned remaining instance {i} to newly " - f"spawned track {instance.track.name} " - f"(best cost = {cost_matrix[i,:].min()})") + logger.debug( + f"[t = {t}] Assigned remaining instance {i} to newly " + f"spawned track {instance.track.name} " + f"(best cost = {cost_matrix[i,:].min()})" + ) # Update the last know data structures for the last frame. self.tracks.update_track_last_known(labeled_frames[img_idx - 1], max_shift=None) @@ -353,9 +413,11 @@ def occupancy(self): occ = np.zeros((len(self.tracks.tracks), int(num_frames)), dtype="bool") for t in range(int(num_frames)): instances = self.tracks.get_frame_instances(t) - instances = [instance for instance in instances if isinstance(instance, Instance)] + instances = [ + instance for instance in instances if isinstance(instance, Instance) + ] for instance in instances: - occ[self.tracks.tracks.index(instance.track),t] = True + occ[self.tracks.tracks.index(instance.track), t] = True return occ @@ -370,22 +432,34 @@ def generate_tracks(self): instance_tracks = np.full((num_frames, num_nodes, 2, num_tracks), np.nan) for t in range(num_frames): instances = self.tracks.get_frame_instances(t) - instances = [instance for instance in instances if isinstance(instance, Instance)] + instances = [ + instance for instance in instances if isinstance(instance, Instance) + ] for instance in instances: - instance_tracks[t, :, :, self.tracks.tracks.index(instance.track)] = instance.points + instance_tracks[ + t, :, :, self.tracks.tracks.index(instance.track) + ] = instance.points return instance_tracks def generate_shifted_data(self): """ Generate arrays with all shifted instance data """ - shifted_instances = [y for x in self.tracks.instances.values() - for y in x if isinstance(y, ShiftedInstance)] + shifted_instances = [ + y + for x in self.tracks.instances.values() + for y in x + if isinstance(y, ShiftedInstance) + ] - track_id = np.array([self.tracks.tracks.index(instance.track) for instance in shifted_instances]) + track_id = np.array( + [self.tracks.tracks.index(instance.track) for instance in shifted_instances] + ) frame_idx = np.array([instance.frame_idx for instance in shifted_instances]) - frame_idx_source = np.array([instance.source.frame_idx for instance in shifted_instances]) + frame_idx_source = np.array( + [instance.source.frame_idx for instance in shifted_instances] + ) points = np.stack([instance.points for instance in shifted_instances], axis=0) return track_id, frame_idx, frame_idx_source, points diff --git a/sleap/nn/training.py b/sleap/nn/training.py index c400c80ff..a05f1e013 100644 --- a/sleap/nn/training.py +++ b/sleap/nn/training.py @@ -2,6 +2,7 @@ import json import logging + logger = logging.getLogger(__name__) import numpy as np @@ -20,8 +21,22 @@ from pathlib import Path, PureWindowsPath from keras import backend as K -from keras.layers import Input, Conv2D, BatchNormalization, Add, MaxPool2D, UpSampling2D, Concatenate -from keras.callbacks import ReduceLROnPlateau, EarlyStopping, TensorBoard, LambdaCallback, ModelCheckpoint +from keras.layers import ( + Input, + Conv2D, + BatchNormalization, + Add, + MaxPool2D, + UpSampling2D, + Concatenate, +) +from keras.callbacks import ( + ReduceLROnPlateau, + EarlyStopping, + TensorBoard, + LambdaCallback, + ModelCheckpoint, +) from sklearn.model_selection import train_test_split @@ -30,7 +45,14 @@ from sleap.nn.augmentation import Augmenter from sleap.nn.model import Model, ModelOutputType from sleap.nn.monitor import LossViewer -from sleap.nn.datagen import generate_training_data, generate_confmaps_from_points, generate_pafs_from_points, generate_images, generate_points, generate_centroid_points +from sleap.nn.datagen import ( + generate_training_data, + generate_confmaps_from_points, + generate_pafs_from_points, + generate_images, + generate_points, + generate_centroid_points, +) @attr.s(auto_attribs=True) @@ -130,18 +152,20 @@ class Trainer: sigma: float = 5.0 instance_crop: bool = False bounding_box_size: int = 0 - min_crop_size: int = 0 - negative_samples: int = 0 - - def train(self, - model: Model, - labels: Union[str, Labels, Dict], - run_name: str = None, - save_dir: Union[str, None] = None, - tensorboard_dir: Union[str, None] = None, - control_zmq_port: int = 9000, - progress_report_zmq_port: int = 9001, - multiprocessing_workers: int = 0) -> str: + min_crop_size: int = 32 + negative_samples: int = 10 + + def train( + self, + model: Model, + labels: Union[str, Labels, Dict], + run_name: str = None, + save_dir: Union[str, None] = None, + tensorboard_dir: Union[str, None] = None, + control_zmq_port: int = 9000, + progress_report_zmq_port: int = 9001, + multiprocessing_workers: int = 0, + ) -> str: """ Train a given model using labels and the Trainer's current hyper-parameter settings. This method executes synchronously, thus it blocks until training is finished. @@ -176,7 +200,6 @@ def train(self, elif type(labels) is dict: labels = Labels.from_json(labels) - # FIXME: We need to handle multiple skeletons. skeleton = labels.skeletons[0] @@ -187,22 +210,24 @@ def train(self, # Generate CENTROID training data if model.output_type == ModelOutputType.CENTROIDS: imgs = generate_images(labels, scale=self.scale) - points = generate_centroid_points( - generate_points(labels, scale=self.scale)) + points = generate_centroid_points(generate_points(labels, scale=self.scale)) # Generate REGULAR training data else: imgs, points = generate_training_data( - labels, - params = dict( - scale = self.scale, - instance_crop = self.instance_crop, - min_crop_size = self.min_crop_size, - negative_samples = self.negative_samples)) + labels, + params=dict( + scale=self.scale, + instance_crop=self.instance_crop, + min_crop_size=self.min_crop_size, + negative_samples=self.negative_samples, + ), + ) # Split data into train/validation - imgs_train, imgs_val, outputs_train, outputs_val = \ - train_test_split(imgs, points, test_size=self.val_size) + imgs_train, imgs_val, outputs_train, outputs_val = train_test_split( + imgs, points, test_size=self.val_size + ) # Free up the original datasets after test and train split. del imgs, points @@ -223,32 +248,36 @@ def train(self, num_outputs_channels = 1 # Determine input and output sizes - # If there are more downsampling layers than upsampling layers, # then the output (confidence maps or part affinity fields) will # be at a different scale than the input (images). - up_down_diff = model.backbone.down_blocks - model.backbone.up_blocks - output_scale = 1/(2**up_down_diff) - input_img_size = (imgs_train.shape[1], imgs_train.shape[2]) - output_img_size = (input_img_size[0]*output_scale, input_img_size[1]*output_scale) + output_img_size = ( + int(input_img_size[0] * model.output_scale), + int(input_img_size[1] * model.output_scale), + ) - logger.info(f"Training set: {imgs_train.shape} -> {output_img_size}, {num_outputs_channels} channels") - logger.info(f"Validation set: {imgs_val.shape} -> {output_img_size}, {num_outputs_channels} channels") + logger.info( + f"Training set: {imgs_train.shape} -> {output_img_size}, {num_outputs_channels} channels" + ) + logger.info( + f"Validation set: {imgs_val.shape} -> {output_img_size}, {num_outputs_channels} channels" + ) # Input layer img_input = Input((img_height, img_width, img_channels)) # Rectify image sizes not divisible by pooling factor - depth = getattr(model.backbone, 'depth', 0) - depth = depth or getattr(model.backbone, 'down_blocks', 0) + depth = getattr(model.backbone, "depth", 0) + depth = depth or getattr(model.backbone, "down_blocks", 0) if depth: pool_factor = 2 ** depth if img_height % pool_factor != 0 or img_width % pool_factor != 0: logger.warning( f"Image dimensions ({img_height}, {img_width}) are " - f"not divisible by the pooling factor ({pool_factor}).") + f"not divisible by the pooling factor ({pool_factor})." + ) # gap_height = (np.ceil(img_height / pool_factor) * pool_factor) - img_height # gap_width = (np.ceil(img_width / pool_factor) * pool_factor) - img_width @@ -258,24 +287,25 @@ def train(self, # Solution: https://www.tensorflow.org/api_docs/python/tf/pad + Lambda layer + corresponding crop at the end? # Instantiate the backbone, this builds the Tensorflow graph - x_outs = model.output(input_tesnor=img_input, num_output_channels=num_outputs_channels) + x_outs = model.output( + input_tensor=img_input, num_output_channels=num_outputs_channels + ) # Create training model by combining the input layer and backbone graph. keras_model = keras.Model(inputs=img_input, outputs=x_outs) # Specify the optimizer. if self.optimizer.lower() == "adam": - _optimizer = keras.optimizers.Adam(lr=self.learning_rate, amsgrad=self.amsgrad) + _optimizer = keras.optimizers.Adam( + lr=self.learning_rate, amsgrad=self.amsgrad + ) elif self.optimizer.lower() == "rmsprop": _optimizer = keras.optimizers.RMSprop(lr=self.learning_rate) else: - raise ValueError(f"Unknown optimizer, value = {optimizer}!") + raise ValueError(f"Unknown optimizer, value = {self.optimizer}!") # Compile the Keras model - keras_model.compile( - optimizer=_optimizer, - loss="mean_squared_error", - ) + keras_model.compile(optimizer=_optimizer, loss="mean_squared_error") logger.info("Params: {:,}".format(keras_model.count_params())) # Default to one loop through dataset per epoch @@ -288,17 +318,34 @@ def train(self, # Setup data generation if model.output_type == ModelOutputType.CONFIDENCE_MAP: + def datagen_function(points): - return generate_confmaps_from_points(points, skeleton, input_img_size, - sigma=self.sigma, scale=output_scale) + return generate_confmaps_from_points( + points, + skeleton, + input_img_size, + sigma=self.sigma, + scale=model.output_scale, + ) + elif model.output_type == ModelOutputType.PART_AFFINITY_FIELD: + def datagen_function(points): - return generate_pafs_from_points(points, skeleton, input_img_size, - sigma=self.sigma, scale=output_scale) + return generate_pafs_from_points( + points, + skeleton, + input_img_size, + sigma=self.sigma, + scale=model.output_scale, + ) + elif model.output_type == ModelOutputType.CENTROIDS: + def datagen_function(points): - return generate_confmaps_from_points(points, None, input_img_size, - node_count=1, sigma=self.sigma) + return generate_confmaps_from_points( + points, None, input_img_size, node_count=1, sigma=self.sigma + ) + else: datagen_function = None @@ -307,15 +354,23 @@ def datagen_function(points): # Initialize data generator with augmentation train_datagen = Augmenter( - imgs_train, points=outputs_train, - datagen=datagen_function, output_names=keras_model.output_names, - batch_size=self.batch_size, shuffle_initially=self.shuffle_initially, + imgs_train, + points=outputs_train, + datagen=datagen_function, + output_names=keras_model.output_names, + batch_size=self.batch_size, + shuffle_initially=self.shuffle_initially, rotation=self.augment_rotation, - scale=(self.augment_scale_min, self.augment_scale_max)) + scale=(self.augment_scale_min, self.augment_scale_max), + ) - train_run = TrainingJob(model=model, trainer=self, - save_dir=save_dir, run_name=run_name, - labels_filename=labels_file_name) + train_run = TrainingJob( + model=model, + trainer=self, + save_dir=save_dir, + run_name=run_name, + labels_filename=labels_file_name, + ) # Setup saving save_path = None @@ -323,8 +378,10 @@ def datagen_function(points): # Generate run name if run_name is None: timestamp = datetime.now().strftime("%y%m%d_%H%M%S") - train_run.run_name = f"{timestamp}.{str(model.output_type)}." \ - f"{model.name}.n={num_total}" + train_run.run_name = ( + f"{timestamp}.{str(model.output_type)}." + f"{model.name}.n={num_total}" + ) # Build save path save_path = os.path.join(save_dir, train_run.run_name) @@ -332,24 +389,38 @@ def datagen_function(points): # Check if it already exists if os.path.exists(save_path): - logger.warning(f"Save path already exists. " - f"Previous run data may be overwritten!") + logger.warning( + f"Save path already exists. " + f"Previous run data may be overwritten!" + ) # Create run folder os.makedirs(save_path, exist_ok=True) # Setup a list of necessary callbacks to invoke while training. + monitor_metric_name = "val_loss" + if len(keras_model.output_names) > 1: + monitor_metric_name = "val_" + keras_model.output_names[-1] + "_loss" callbacks = self._setup_callbacks( - train_run, save_path, train_datagen, - tensorboard_dir, control_zmq_port, - progress_report_zmq_port, output_type=str(model.output_type)) + train_run, + save_path, + train_datagen, + tensorboard_dir, + control_zmq_port, + progress_report_zmq_port, + output_type=str(model.output_type), + monitor_metric_name=monitor_metric_name, + ) # Train! history = keras_model.fit_generator( train_datagen, steps_per_epoch=steps_per_epoch, epochs=self.num_epochs, - validation_data=(imgs_val, {output_name: outputs_val for output_name in keras_model.output_names}), + validation_data=( + imgs_val, + {output_name: outputs_val for output_name in keras_model.output_names}, + ), callbacks=callbacks, verbose=2, use_multiprocessing=multiprocessing_workers > 0, @@ -359,7 +430,9 @@ def datagen_function(points): # Save once done training if save_path is not None: final_model_path = os.path.join(save_path, "final_model.h5") - keras_model.save(filepath=final_model_path, overwrite=True, include_optimizer=True) + keras_model.save( + filepath=final_model_path, overwrite=True, include_optimizer=True + ) logger.info(f"Saved final model: {final_model_path}") # TODO: save training history @@ -399,10 +472,17 @@ def train_async(self, *args, **kwargs) -> Tuple[Pool, AsyncResult]: return pool, result - def _setup_callbacks(self, train_run: 'TrainingJob', - save_path, train_datagen, - tensorboard_dir, control_zmq_port, - progress_report_zmq_port, output_type): + def _setup_callbacks( + self, + train_run: "TrainingJob", + save_path, + train_datagen, + tensorboard_dir, + control_zmq_port, + progress_report_zmq_port, + output_type, + monitor_metric_name="val_loss", + ): """ Setup callbacks for the call to Keras fit_generator. @@ -417,61 +497,98 @@ def _setup_callbacks(self, train_run: 'TrainingJob', if save_path is not None: if self.save_every_epoch: full_path = os.path.join(save_path, "newest_model.h5") - train_run.newest_model_filename = os.path.relpath(full_path, train_run.save_dir) + train_run.newest_model_filename = os.path.relpath( + full_path, train_run.save_dir + ) callbacks.append( - ModelCheckpoint(filepath=full_path, - monitor="val_loss", save_best_only=False, - save_weights_only=False, period=1)) + ModelCheckpoint( + filepath=full_path, + monitor=monitor_metric_name, + save_best_only=False, + save_weights_only=False, + period=1, + ) + ) if self.save_best_val: full_path = os.path.join(save_path, "best_model.h5") - train_run.best_model_filename = os.path.relpath(full_path, train_run.save_dir) + train_run.best_model_filename = os.path.relpath( + full_path, train_run.save_dir + ) callbacks.append( - ModelCheckpoint(filepath=full_path, - monitor="val_loss", save_best_only=True, - save_weights_only=False, period=1)) + ModelCheckpoint( + filepath=full_path, + monitor=monitor_metric_name, + save_best_only=True, + save_weights_only=False, + period=1, + ) + ) TrainingJob.save_json(train_run, f"{save_path}.json") # Callbacks: Shuffle after every epoch if self.shuffle_every_epoch: callbacks.append( - LambdaCallback(on_epoch_end=lambda epoch, logs: train_datagen.shuffle())) + LambdaCallback(on_epoch_end=lambda epoch, logs: train_datagen.shuffle()) + ) # Callbacks: LR reduction callbacks.append( - ReduceLROnPlateau(min_delta=self.reduce_lr_min_delta, - factor=self.reduce_lr_factor, - patience=self.reduce_lr_patience, - cooldown=self.reduce_lr_cooldown, - min_lr=self.reduce_lr_min_lr, - monitor="val_loss", mode="auto", verbose=1, ) + ReduceLROnPlateau( + min_delta=self.reduce_lr_min_delta, + factor=self.reduce_lr_factor, + patience=self.reduce_lr_patience, + cooldown=self.reduce_lr_cooldown, + min_lr=self.reduce_lr_min_lr, + monitor=monitor_metric_name, + mode="auto", + verbose=1, + ) ) # Callbacks: Early stopping callbacks.append( - EarlyStopping(monitor="val_loss", - min_delta=self.early_stopping_min_delta, - patience=self.early_stopping_patience, verbose=1)) + EarlyStopping( + monitor=monitor_metric_name, + min_delta=self.early_stopping_min_delta, + patience=self.early_stopping_patience, + verbose=1, + ) + ) # Callbacks: Tensorboard if tensorboard_dir is not None: callbacks.append( - TensorBoard(log_dir=f"{tensorboard_dir}/{output_type}{time()}", - batch_size=32, update_freq=150, histogram_freq=0, - write_graph=False, write_grads=False, write_images=False, - embeddings_freq=0, embeddings_layer_names=None, - embeddings_metadata=None, embeddings_data=None)) + TensorBoard( + log_dir=f"{tensorboard_dir}/{output_type}{time()}", + batch_size=32, + update_freq=150, + histogram_freq=0, + write_graph=False, + write_grads=False, + write_images=False, + embeddings_freq=0, + embeddings_layer_names=None, + embeddings_metadata=None, + embeddings_data=None, + ) + ) # Callbacks: ZMQ control if control_zmq_port is not None: callbacks.append( - TrainingControllerZMQ(address="tcp://127.0.0.1", - port=control_zmq_port, - topic="", poll_timeout=10)) + TrainingControllerZMQ( + address="tcp://127.0.0.1", + port=control_zmq_port, + topic="", + poll_timeout=10, + ) + ) # Callbacks: ZMQ progress reporter if progress_report_zmq_port is not None: callbacks.append( - ProgressReporterZMQ(port=progress_report_zmq_port, what=output_type)) + ProgressReporterZMQ(port=progress_report_zmq_port, what=output_type) + ) return callbacks @@ -498,6 +615,7 @@ class TrainingJob: from the final state of training. Set to None if save_dir is None. This model file is not created until training is finished. """ + model: Model trainer: Trainer labels_filename: Union[str, None] = None @@ -508,7 +626,7 @@ class TrainingJob: final_model_filename: Union[str, None] = None @staticmethod - def save_json(training_job: 'TrainingJob', filename: str): + def save_json(training_job: "TrainingJob", filename: str): """ Save a training run to a JSON file. @@ -520,7 +638,7 @@ def save_json(training_job: 'TrainingJob', filename: str): None """ - with open(filename, 'w') as file: + with open(filename, "w") as file: # We have some skeletons to deal with, make sure to setup a Skeleton cattr. my_cattr = Skeleton.make_cattr() @@ -528,7 +646,6 @@ def save_json(training_job: 'TrainingJob', filename: str): json_str = json.dumps(dicts) file.write(json_str) - @classmethod def load_json(cls, filename: str): """ @@ -542,26 +659,36 @@ def load_json(cls, filename: str): """ # Open and parse the JSON in filename - with open(filename, 'r') as file: - json_str = file.read() - dicts = json.loads(json_str) + with open(filename, "r") as f: + dicts = json.load(f) - # We have some skeletons to deal with, make sure to setup a Skeleton cattr. - my_cattr = Skeleton.make_cattr() + # We have some skeletons to deal with, make sure to setup a Skeleton cattr. + converter = Skeleton.make_cattr() + + # Structure the nested skeletons if we have any. + if ("model" in dicts) and ("skeletons" in dicts["model"]): + if dicts["model"]["skeletons"]: + dicts["model"]["skeletons"] = converter.structure( + dicts["model"]["skeletons"], List[Skeleton] + ) + + else: + dicts["model"]["skeletons"] = [] - try: - run = my_cattr.structure(dicts, cls) - except: - raise ValueError(f"Failure deserializing {filename} to TrainingJob.") + # Setup structuring hook for unambiguous backbone class resolution. + converter.register_structure_hook(Model, Model._structure_model) - # if we can't find save_dir for job, set it to path of json we're loading - if run.save_dir is not None: - if not os.path.exists(run.save_dir): - run.save_dir = os.path.dirname(filename) + # Build classes. + run = converter.structure(dicts, cls) - run.final_model_filename = cls._fix_path(run.final_model_filename) - run.best_model_filename = cls._fix_path(run.best_model_filename) - run.newest_model_filename = cls._fix_path(run.newest_model_filename) + # if we can't find save_dir for job, set it to path of json we're loading + if run.save_dir is not None: + if not os.path.exists(run.save_dir): + run.save_dir = os.path.dirname(filename) + + run.final_model_filename = cls._fix_path(run.final_model_filename) + run.best_model_filename = cls._fix_path(run.best_model_filename) + run.newest_model_filename = cls._fix_path(run.newest_model_filename) return run @@ -585,7 +712,9 @@ def __init__(self, address="tcp://127.0.0.1", port=9000, topic="", poll_timeout= self.socket = self.context.socket(zmq.SUB) self.socket.subscribe(self.topic) self.socket.connect(self.address) - logger.info(f"Training controller subscribed to: {self.address} (topic: {self.topic})") + logger.info( + f"Training controller subscribed to: {self.address} (topic: {self.topic})" + ) # TODO: catch/throw exception about failure to connect @@ -652,17 +781,26 @@ def on_train_begin(self, logs=None): logs: dict, currently no data is passed to this argument for this method but that may change in the future. """ - self.socket.send_string(jsonpickle.encode(dict(what=self.what,event="train_begin", logs=logs))) - + self.socket.send_string( + jsonpickle.encode(dict(what=self.what, event="train_begin", logs=logs)) + ) def on_batch_begin(self, batch, logs=None): """A backwards compatibility alias for `on_train_batch_begin`.""" # self.logger.info("batch_begin") - self.socket.send_string(jsonpickle.encode(dict(what=self.what,event="batch_begin", batch=batch, logs=logs))) + self.socket.send_string( + jsonpickle.encode( + dict(what=self.what, event="batch_begin", batch=batch, logs=logs) + ) + ) def on_batch_end(self, batch, logs=None): """A backwards compatibility alias for `on_train_batch_end`.""" - self.socket.send_string(jsonpickle.encode(dict(what=self.what,event="batch_end", batch=batch, logs=logs))) + self.socket.send_string( + jsonpickle.encode( + dict(what=self.what, event="batch_end", batch=batch, logs=logs) + ) + ) def on_epoch_begin(self, epoch, logs=None): """Called at the start of an epoch. @@ -673,7 +811,11 @@ def on_epoch_begin(self, epoch, logs=None): logs: dict, currently no data is passed to this argument for this method but that may change in the future. """ - self.socket.send_string(jsonpickle.encode(dict(what=self.what,event="epoch_begin", epoch=epoch, logs=logs))) + self.socket.send_string( + jsonpickle.encode( + dict(what=self.what, event="epoch_begin", epoch=epoch, logs=logs) + ) + ) def on_epoch_end(self, epoch, logs=None): """Called at the end of an epoch. @@ -685,7 +827,11 @@ def on_epoch_end(self, epoch, logs=None): validation epoch if validation is performed. Validation result keys are prefixed with `val_`. """ - self.socket.send_string(jsonpickle.encode(dict(what=self.what,event="epoch_end", epoch=epoch, logs=logs))) + self.socket.send_string( + jsonpickle.encode( + dict(what=self.what, event="epoch_end", epoch=epoch, logs=logs) + ) + ) def on_train_end(self, logs=None): """Called at the end of training. @@ -694,31 +840,48 @@ def on_train_end(self, logs=None): logs: dict, currently no data is passed to this argument for this method but that may change in the future. """ - self.socket.send_string(jsonpickle.encode(dict(what=self.what,event="train_end", logs=logs))) + self.socket.send_string( + jsonpickle.encode(dict(what=self.what, event="train_end", logs=logs)) + ) + def main(): from PySide2 import QtWidgets -# from sleap.nn.architectures.unet import UNet -# model = Model(output_type=ModelOutputType.CONFIDENCE_MAP, -# backbone=UNet(num_filters=16, depth=3, up_blocks=2)) + # from sleap.nn.architectures.unet import UNet + # model = Model(output_type=ModelOutputType.CONFIDENCE_MAP, + # backbone=UNet(num_filters=16, depth=3, up_blocks=2)) from sleap.nn.architectures.leap import LeapCNN - model = Model(output_type=ModelOutputType.PART_AFFINITY_FIELD, - backbone=LeapCNN(down_blocks=3, up_blocks=2, - upsampling_layers=True, num_filters=32, interp="bilinear")) + + model = Model( + output_type=ModelOutputType.PART_AFFINITY_FIELD, + backbone=LeapCNN( + down_blocks=3, + up_blocks=2, + upsampling_layers=True, + num_filters=32, + interp="bilinear", + ), + ) # Setup a Trainer object to train the model above - trainer = Trainer(val_size=0.1, batch_size=4, - num_epochs=10, steps_per_epoch=5, - save_best_val=True, - save_every_epoch=True) + trainer = Trainer( + val_size=0.1, + batch_size=4, + num_epochs=10, + steps_per_epoch=5, + save_best_val=True, + save_every_epoch=True, + ) # Run training asynchronously - pool, result = trainer.train_async(model=model, - labels=Labels.load_json("tests/data/json_format_v1/centered_pair.json"), - save_dir='test_train/', - run_name="training_run_2") + pool, result = trainer.train_async( + model=model, + labels=Labels.load_json("tests/data/json_format_v1/centered_pair.json"), + save_dir="test_train/", + run_name="training_run_2", + ) app = QtWidgets.QApplication() @@ -730,7 +893,7 @@ def main(): while not result.ready(): app.processEvents() - result.wait(.01) + result.wait(0.01) print("Get") train_job_path = result.get() @@ -742,13 +905,21 @@ def main(): # Now lets load the training job we just ran train_job = TrainingJob.load_json(train_job_path) - assert os.path.exists(os.path.join(train_job.save_dir, train_job.newest_model_filename)) - assert os.path.exists(os.path.join(train_job.save_dir, train_job.best_model_filename)) - assert os.path.exists(os.path.join(train_job.save_dir, train_job.final_model_filename)) + assert os.path.exists( + os.path.join(train_job.save_dir, train_job.newest_model_filename) + ) + assert os.path.exists( + os.path.join(train_job.save_dir, train_job.best_model_filename) + ) + assert os.path.exists( + os.path.join(train_job.save_dir, train_job.final_model_filename) + ) import sys + sys.exit(0) + def run(labels_filename: str, job_filename: str): labels = Labels.load_file(labels_filename) @@ -758,12 +929,13 @@ def run(labels_filename: str, job_filename: str): save_dir = os.path.join(os.path.dirname(labels_filename), "models") job.trainer.train( - model=job.model, - labels=labels, - save_dir=save_dir, - control_zmq_port=None, - progress_report_zmq_port=None - ) + model=job.model, + labels=labels, + save_dir=save_dir, + control_zmq_port=None, + progress_report_zmq_port=None, + ) + if __name__ == "__main__": import argparse @@ -777,7 +949,9 @@ def run(labels_filename: str, job_filename: str): job_filename = args.profile_path if not os.path.exists(job_filename): - profile_dir = resource_filename(Requirement.parse("sleap"), "sleap/training_profiles") + profile_dir = resource_filename( + Requirement.parse("sleap"), "sleap/training_profiles" + ) if os.path.exists(os.path.join(profile_dir, job_filename)): job_filename = os.path.join(profile_dir, job_filename) else: @@ -786,7 +960,4 @@ def run(labels_filename: str, job_filename: str): print(f"Training labels file: {args.labels_path}") print(f"Training profile: {job_filename}") - run( - labels_filename=args.labels_path, - job_filename=job_filename) - + run(labels_filename=args.labels_path, job_filename=job_filename) diff --git a/sleap/nn/transform.py b/sleap/nn/transform.py index 9389d70f5..89f71b77e 100644 --- a/sleap/nn/transform.py +++ b/sleap/nn/transform.py @@ -5,6 +5,7 @@ from sleap.nn.datagen import _bbs_from_points, _pad_bbs, _crop + @attr.s(auto_attribs=True, slots=True) class DataTransform: """ @@ -23,7 +24,9 @@ def _init_frame_idxs(self, frame_count): self.frame_idxs = list(range(frame_count)) def get_data_idxs(self, frame_idx): - return [i for i in range(len(self.frame_idxs)) if self.frame_idxs[i] == frame_idx] + return [ + i for i in range(len(self.frame_idxs)) if self.frame_idxs[i] == frame_idx + ] def get_frame_idxs(self, idxs): if type(idxs) == int: @@ -51,7 +54,7 @@ def scale_to(self, imgs, target_size): self._init_frame_idxs(img_count) # update object state (so we can invert) - self.scale = self.scale * (h/img_h) + self.scale = self.scale * (h / img_h) # return the scaled images return self._scale(imgs, target_size) @@ -67,7 +70,7 @@ def invert_scale(self, imgs): """ # determine target size for inverting scale img_count, img_h, img_w, img_channels = imgs.shape - target_size = (img_h * int(1/self.scale), img_w * int(1/self.scale)) + target_size = (img_h * int(1 / self.scale), img_w * int(1 / self.scale)) return self.scale_to(imgs, target_size) @@ -79,12 +82,14 @@ def _scale(self, imgs, target_size): if (img_h, img_w) != target_size: # build ndarray for new size - scaled_imgs = np.zeros((imgs.shape[0], h, w, imgs.shape[3])) + scaled_imgs = np.zeros( + (imgs.shape[0], h, w, imgs.shape[3]), dtype=imgs.dtype + ) for i in range(imgs.shape[0]): # resize using cv2 img = cv2.resize(imgs[i, :, :], (w, h)) - # add back singleton channel (removed by cv2) + # add back singleton channel (removed by cv2) if img_channels == 1: img = img[..., None] else: @@ -99,7 +104,7 @@ def _scale(self, imgs, target_size): return scaled_imgs - def centroid_crop(self, imgs: np.ndarray, centroids: list, crop_size: int=0): + def centroid_crop(self, imgs: np.ndarray, centroids: list, crop_size: int = 0): """ Crop images around centroid points. Updates state of DataTransform object so we can later invert on points. @@ -122,7 +127,7 @@ def centroid_crop(self, imgs: np.ndarray, centroids: list, crop_size: int=0): # Crop images return self.crop(imgs, bbs, idxs) - def crop(self, imgs:np.ndarray, boxes: list, idxs: list) -> np.ndarray: + def crop(self, imgs: np.ndarray, boxes: list, idxs: list) -> np.ndarray: """ Crop images to given boxes. @@ -172,7 +177,7 @@ def invert(self, idx: int, point_array: np.ndarray) -> np.ndarray: # translate point_array using corresponding bounding_box bb = self.bounding_boxes[idx] - top_left_point = ((bb[0], bb[1]),) # for (x, y) row vector + top_left_point = ((bb[0], bb[1]),) # for (x, y) row vector new_point_array += np.array(top_left_point) - return new_point_array \ No newline at end of file + return new_point_array diff --git a/sleap/nn/util.py b/sleap/nn/util.py index 9eb4647a7..b4cfd65ae 100644 --- a/sleap/nn/util.py +++ b/sleap/nn/util.py @@ -4,10 +4,13 @@ def batch_count(data, batch_size): """Return number of batch_size batches into which data can be divided.""" from math import ceil + return ceil(len(data) / batch_size) -def batch(data: Sequence, batch_size: int) -> Generator[Tuple[int, int, Sequence], None, None]: +def batch( + data: Sequence, batch_size: int +) -> Generator[Tuple[int, int, Sequence], None, None]: """Iterate over sequence data in batches. Arguments: @@ -18,10 +21,10 @@ def batch(data: Sequence, batch_size: int) -> Generator[Tuple[int, int, Sequence * batch number (int) * row offset (int) * batch_size number of items from data - """ + """ total_row_count = len(data) for start in range(0, total_row_count, batch_size): - i = start//batch_size + i = start // batch_size end = min(start + batch_size, total_row_count) yield i, start, data[start:end] @@ -36,7 +39,7 @@ def save_visual_outputs(output_path: str, data: dict): # output_path is full path to labels.json, so replace "json" with "h5" viz_output_path = output_path if viz_output_path.endswith(".json"): - viz_output_path = viz_output_path[:-(len(".json"))] + viz_output_path = viz_output_path[: -(len(".json"))] viz_output_path += ".h5" # write file @@ -45,11 +48,15 @@ def save_visual_outputs(output_path: str, data: dict): val = np.array(val) if key in f: f[key].resize(f[key].shape[0] + val.shape[0], axis=0) - f[key][-val.shape[0]:] = val + f[key][-val.shape[0] :] = val else: maxshape = (None, *val.shape[1:]) - f.create_dataset(key, data=val, maxshape=maxshape, - compression="gzip", compression_opts=9) + f.create_dataset( + key, + data=val, + maxshape=maxshape, + compression="gzip", + compression_opts=9, + ) # logger.info(" Saved visual outputs [%.1fs]" % (time() - t0)) - diff --git a/sleap/rangelist.py b/sleap/rangelist.py index 4c5e0701d..2ccc1175c 100644 --- a/sleap/rangelist.py +++ b/sleap/rangelist.py @@ -1,10 +1,20 @@ -class RangeList(): +""" +Module with RangeList class for manipulating a list of range intervals. + +This is used to cache the track occupancy so we can keep cache updating +when user manipulates tracks for a range of instances. +""" + +from typing import List, Tuple + + +class RangeList: """ Class for manipulating a list of range intervals. Each range interval in the list is a [start, end)-tuple. """ - def __init__(self, range_list: list=None): + def __init__(self, range_list: List[Tuple[int]] = None): self.list = range_list if range_list is not None else [] def __repr__(self): @@ -19,48 +29,53 @@ def list(self): def list(self, val): """Sets the list of ranges.""" self._list = val -# for i, r in enumerate(self._list): -# if type(r) == tuple: -# self._list[i] = range(r[0], r[1]) @property def is_empty(self): """Returns True if the list is empty.""" return len(self.list) == 0 + @property + def start(self): + """Return the start value of range (or None if empty).""" + if self.is_empty: + return None + return self.list[0][0] + def add(self, val, tolerance=0): - """Adds a single value, merges to last range if contiguous.""" - if len(self.list) and self.list[-1][1] + tolerance >= val: - self.list[-1] = (self.list[-1][0], val+1) + """Add a single value, merges to last range if contiguous.""" + if self.list and self.list[-1][1] + tolerance >= val: + self.list[-1] = (self.list[-1][0], val + 1) else: - self.list.append((val, val+1)) + self.list.append((val, val + 1)) def insert(self, new_range: tuple): - """Adds a new range, merging to adjacent/overlapping ranges as appropriate.""" + """Add a new range, merging to adjacent/overlapping ranges as appropriate.""" new_range = self._as_tuple(new_range) - pre, within, post = self.cut_range(new_range) + pre, _, post = self.cut_range(new_range) self.list = self.join_([pre, [new_range], post]) return self.list - def insert_list(self, range_list: list): - """Adds each range from a list of ranges.""" + def insert_list(self, range_list: List[Tuple[int]]): + """Add each range from a list of ranges.""" for range_ in range_list: self.insert(range_) return self.list def remove(self, remove: tuple): - """Removes everything that overlaps with given range.""" - pre, within, post = self.cut_range(remove) + """Remove everything that overlaps with given range.""" + pre, _, post = self.cut_range(remove) self.list = pre + post def cut(self, cut: int): - """Returns a pair of lists with everything before/after cut.""" + """Return a pair of lists with everything before/after cut.""" return self.cut_(self.list, cut) def cut_range(self, cut: tuple): - """Returns three lists, everthing before/within/after cut range.""" - if len(self.list) == 0: return [], [], [] + """Return three lists, everthing before/within/after cut range.""" + if not self.list: + return [], [], [] cut = self._as_tuple(cut) a, r = self.cut_(self.list, cut[0]) @@ -70,11 +85,20 @@ def cut_range(self, cut: tuple): @staticmethod def _as_tuple(x): - if type(x) == range: return x.start, x.stop + """Return tuple (converting from range if necessary).""" + if isinstance(x, range): + return x.start, x.stop return x @staticmethod - def cut_(range_list: list, cut: int): + def cut_(range_list: List[Tuple[int]], cut: int): + """Return a pair of lists with everything before/after cut. + Args: + range_list: the list to cut + cut: the value at which to cut list + Returns: + (pre-cut list, post-cut list)-tuple + """ pre = [] post = [] @@ -83,7 +107,7 @@ def cut_(range_list: list, cut: int): pre.append(range_) elif range_[0] >= cut: post.append(range_) - elif range_[0] < cut and range_[1] > cut: + elif range_[0] < cut < range_[1]: # two new ranges, split at cut a = (range_[0], cut) b = (cut, range_[1]) @@ -92,18 +116,29 @@ def cut_(range_list: list, cut: int): return pre, post @classmethod - def join_(cls, list_list: list): - if len(list_list) == 1: return list_list[0] - if len(list_list) == 2: return cls.join_pair_(list_list[0], list_list[1]) - else: return cls.join_pair_(list_list[0], cls.join_(list_list[1:])) + def join_(cls, list_list: List[List[Tuple[int]]]): + """Return a single list that includes all lists in input list. + + Args: + list_list: a list of range lists + Returns: + range list that joins all of the lists in list_list + """ + if len(list_list) == 1: + return list_list[0] + if len(list_list) == 2: + return cls.join_pair_(list_list[0], list_list[1]) + return cls.join_pair_(list_list[0], cls.join_(list_list[1:])) @staticmethod - def join_pair_(list_a: list, list_b: list): - if len(list_a) == 0 or len(list_b) == 0: return list_a + list_b - + def join_pair_(list_a: List[Tuple[int]], list_b: List[Tuple[int]]): + """Return a single pair of lists that joins two input lists.""" + if not list_a or not list_b: + return list_a + list_b + last_a = list_a[-1] first_b = list_b[0] if last_a[1] >= first_b[0]: return list_a[:-1] + [(last_a[0], first_b[1])] + list_b[1:] - else: - return list_a + list_b + + return list_a + list_b diff --git a/sleap/skeleton.py b/sleap/skeleton.py index d3450f460..d0c290396 100644 --- a/sleap/skeleton.py +++ b/sleap/skeleton.py @@ -1,9 +1,9 @@ -"""Implementation of skeleton data structure and API. - -This module implements and API for creating animal skeleton's in LEAP. The goal -is to provide a common interface for defining the parts of the animal, their -connection to each other, and needed meta-data. +""" +Implementation of skeleton data structure and API. +This module implements and API for creating animal skeletons. The goal +is to provide a common interface for defining the parts of the animal, +their connection to each other, and needed meta-data. """ import attr @@ -11,27 +11,33 @@ import numpy as np import jsonpickle import json -import networkx as nx import h5py as h5 import copy from enum import Enum from itertools import count -from typing import Iterable, Union, List, Dict +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +import networkx as nx from networkx.readwrite import json_graph -from scipy.io import loadmat, savemat +from scipy.io import loadmat + +NodeRef = Union[str, "Node"] +H5FileRef = Union[str, h5.File] class EdgeType(Enum): """ The skeleton graph can store different types of edges to represent - different things. All edges must specify one or more of the following types. + different things. All edges must specify one or more of the + following types: - * BODY - these edges represent connections between parts or landmarks. - * SYMMETRY - these edges represent symmetrical relationships between - parts (e.g. left and right arms) + * BODY - these edges represent connections between parts or + landmarks. + * SYMMETRY - these edges represent symmetrical relationships + between parts (e.g. left and right arms) """ + BODY = 1 SYMMETRY = 2 @@ -44,25 +50,27 @@ class Node: """ name: str - weight: float = 1. + weight: float = 1.0 @staticmethod - def from_names(name_list: str): + def from_names(name_list: str) -> List["Node"]: + """Convert list of node names to list of nodes objects.""" nodes = [] for name in name_list: nodes.append(Node(name)) return nodes @classmethod - def as_node(cls, node): + def as_node(cls, node: NodeRef) -> "Node": + """Convert given `node` to `Node` object (if not already).""" return node if isinstance(node, cls) else cls(node) - def matches(self, other): + def matches(self, other: "Node") -> bool: """ Check whether all attributes match between two nodes. Args: - other: The node to compare to this one. + other: The `Node` to compare to this one. Returns: True if all attributes match, False otherwise. @@ -71,40 +79,40 @@ def matches(self, other): class Skeleton: - """The main object for representing animal skeletons in LEAP. + """ + The main object for representing animal skeletons. - The skeleton represents the constituent parts of the animal whose pose - is being estimated. + The skeleton represents the constituent parts of the animal whose + pose is being estimated. + An index variable used to give skeletons a default name that should + be unique across all skeletons. """ - """ - A index variable used to give skeletons a default name that attemtpts to be - unique across all skeletons. - """ _skeleton_idx = count(0) def __init__(self, name: str = None): - """Initialize an empty skeleton object. + """ + Initialize an empty skeleton object. - Skeleton objects, once they are created can be modified by adding nodes and edges. + Skeleton objects, once created, can be modified by adding nodes + and edges. Args: name: A name for this skeleton. """ # If no skeleton was create, try to create a unique name for this Skeleton. - if name is None or type(name) is not str or len(name) == 0: + if name is None or not isinstance(name, str) or not name: name = "Skeleton-" + str(next(self._skeleton_idx)) - # Since networkx does not keep edges in the order we insert them we need # to keep track of how many edges have been inserted so we can number them # as they are inserted and sort them by this numbering when the edge list # is returned. self._graph: nx.MultiDiGraph = nx.MultiDiGraph(name=name, num_edges_inserted=0) - def matches(self, other: 'Skeleton'): + def matches(self, other: "Skeleton") -> bool: """ Compare this `Skeleton` to another, ignoring skeleton name and the identities of the `Node` objects in each graph. @@ -115,11 +123,14 @@ def matches(self, other: 'Skeleton'): Returns: True if match, False otherwise. """ + def dict_match(dict1, dict2): return dict1 == dict2 # Check if the graphs are iso-morphic - is_isomorphic = nx.is_isomorphic(self._graph, other._graph, node_match=dict_match) + is_isomorphic = nx.is_isomorphic( + self._graph, other._graph, node_match=dict_match + ) if not is_isomorphic: return False @@ -135,51 +146,72 @@ def dict_match(dict1, dict2): @property def graph(self): - edges = [(src, dst, key) for src, dst, key, edge_type in self._graph.edges(keys=True, data="type") if edge_type == EdgeType.BODY] + """Returns subgraph of BODY edges for skeleton.""" + edges = [ + (src, dst, key) + for src, dst, key, edge_type in self._graph.edges(keys=True, data="type") + if edge_type == EdgeType.BODY + ] # TODO: properly induce subgraph for MultiDiGraph - # Currently, NetworkX will just return the nodes in the subgraph. + # Currently, NetworkX will just return the nodes in the subgraph. # See: https://stackoverflow.com/questions/16150557/networkxcreating-a-subgraph-induced-from-edges return self._graph.edge_subgraph(edges) @property def graph_symmetry(self): - edges = [(src, dst, key) for src, dst, key, edge_type in self._graph.edges(keys=True, data="type") if edge_type == EdgeType.SYMMETRY] + """Returns subgraph of symmetric edges for skeleton.""" + edges = [ + (src, dst, key) + for src, dst, key, edge_type in self._graph.edges(keys=True, data="type") + if edge_type == EdgeType.SYMMETRY + ] return self._graph.edge_subgraph(edges) @staticmethod - def find_unique_nodes(skeletons: List['Skeleton']): + def find_unique_nodes(skeletons: List["Skeleton"]) -> List[Node]: """ - Given list of skeletons, return a list of unique node objects across all skeletons. + Find all unique nodes from a list of skeletons. Args: skeletons: The list of skeletons. Returns: - A list of unique node objects. + A list of unique `Node` objects. """ return list({node for skeleton in skeletons for node in skeleton.nodes}) @staticmethod - def make_cattr(idx_to_node: Dict[int, Node] = None): + def make_cattr(idx_to_node: Dict[int, Node] = None) -> cattr.Converter: """ - Create a cattr.Converter() that registers structure and unstructure hooks for - Skeleton objects that handle serialization of skeletons objects. + Make cattr.Convert() for `Skeleton`. + + Make a cattr.Converter() that registers structure/unstructure + hooks for Skeleton objects to handle serialization of skeletons. Args: idx_to_node: A dict that maps node index to Node objects. Returns: - A cattr.Converter() instance ready for skeleton serialization and deserialization. + A cattr.Converter() instance for skeleton serialization + and deserialization. """ - node_to_idx = {node:idx for idx,node in idx_to_node.items()} if idx_to_node is not None else None + node_to_idx = ( + {node: idx for idx, node in idx_to_node.items()} + if idx_to_node is not None + else None + ) _cattr = cattr.Converter() - _cattr.register_unstructure_hook(Skeleton, lambda x: Skeleton.to_dict(x, node_to_idx)) - _cattr.register_structure_hook(Skeleton, lambda x,type: Skeleton.from_dict(x, idx_to_node)) + _cattr.register_unstructure_hook( + Skeleton, lambda x: Skeleton.to_dict(x, node_to_idx) + ) + _cattr.register_structure_hook( + Skeleton, lambda x, cls: Skeleton.from_dict(x, idx_to_node) + ) return _cattr @property - def name(self): + def name(self) -> str: """Get the name of the skeleton. Returns: @@ -190,30 +222,38 @@ def name(self): @name.setter def name(self, name: str): """ - A skeleton object cannot change its name. This property is immutable because it is - used to hash skeletons. If you want to rename a Skeleton you must use the class + A skeleton object cannot change its name. + + This property is immutable because it is used to hash skeletons. + If you want to rename a Skeleton you must use the class method :code:`rename_skeleton`: - >>> new_skeleton = Skeleton.rename_skeleton(skeleton=old_skeleton, name="New Name") + >>> new_skeleton = Skeleton.rename_skeleton( + >>> skeleton=old_skeleton, name="New Name") Args: name: The name of the Skeleton. Raises: - NotImplementedError + NotImplementedError: Error is always raised. """ - raise NotImplementedError("Cannot change Skeleton name, it is immutable since " + - "it is used for hashing. Create a copy of the skeleton " + - "with new name using " + - f"new_skeleton = Skeleton.rename(skeleton, '{name}'))") + raise NotImplementedError( + "Cannot change Skeleton name, it is immutable since " + + "it is used for hashing. Create a copy of the skeleton " + + "with new name using " + + f"new_skeleton = Skeleton.rename(skeleton, '{name}'))" + ) @classmethod - def rename_skeleton(cls, skeleton: 'Skeleton', name: str) -> 'Skeleton': + def rename_skeleton(cls, skeleton: "Skeleton", name: str) -> "Skeleton": """ - A skeleton object cannot change its name. This property is immutable because it is - used to hash skeletons. If you want to rename a Skeleton you must use this classmethod. + Make copy of skeleton with new name. + + This property is immutable because it is used to hash skeletons. + If you want to rename a Skeleton you must use this class method. - >>> new_skeleton = Skeleton.rename_skeleton(skeleton=old_skeleton, name="New Name") + >>> new_skeleton = Skeleton.rename_skeleton( + >>> skeleton=old_skeleton, name="New Name") Args: skeleton: The skeleton to copy. @@ -228,7 +268,7 @@ def rename_skeleton(cls, skeleton: 'Skeleton', name: str) -> 'Skeleton': return new_skeleton @property - def nodes(self): + def nodes(self) -> List[Node]: """Get a list of :class:`Node`s. Returns: @@ -237,7 +277,7 @@ def nodes(self): return list(self._graph.nodes) @property - def node_names(self): + def node_names(self) -> List[str]: """Get a list of node names. Returns: @@ -246,15 +286,17 @@ def node_names(self): return [node.name for node in self.nodes] @property - def edges(self): + def edges(self) -> List[Tuple[Node, Node]]: """Get a list of edge tuples. Returns: list of (src_node, dst_node) """ - edge_list = [(d['edge_insert_idx'], src, dst) - for src, dst, key, d in self._graph.edges(keys=True, data=True) - if d['type'] == EdgeType.BODY] + edge_list = [ + (d["edge_insert_idx"], src, dst) + for src, dst, key, d in self._graph.edges(keys=True, data=True) + if d["type"] == EdgeType.BODY + ] # We don't want to return the edge list in the order it is stored. We # want to use the insertion order. Sort by the insertion index for each @@ -264,15 +306,17 @@ def edges(self): return edge_list @property - def edge_names(self): + def edge_names(self) -> List[Tuple[str, str]]: """Get a list of edge name tuples. Returns: list of (src_node.name, dst_node.name) """ - edge_list = [(d['edge_insert_idx'], src.name, dst.name) - for src, dst, key, d in self._graph.edges(keys=True, data=True) - if d['type'] == EdgeType.BODY] + edge_list = [ + (d["edge_insert_idx"], src.name, dst.name) + for src, dst, key, d in self._graph.edges(keys=True, data=True) + if d["type"] == EdgeType.BODY + ] # We don't want to return the edge list in the order it is stored. We # want to use the insertion order. Sort by the insertion index for each @@ -282,46 +326,62 @@ def edge_names(self): return [(src.name, dst.name) for src, dst in self.edges] @property - def edges_full(self): + def edges_full(self) -> List[Tuple[Node, Node, Any, Any]]: """Get a list of edge tuples with keys and attributes. Returns: list of (src_node, dst_node, key, attributes) """ - return [(src, dst, key, attr) for src, dst, key, attr in self._graph.edges(keys=True, data=True) if attr["type"] == EdgeType.BODY] + return [ + (src, dst, key, attr) + for src, dst, key, attr in self._graph.edges(keys=True, data=True) + if attr["type"] == EdgeType.BODY + ] @property - def symmetries(self): + def symmetries(self) -> List[Tuple[Node, Node]]: """Get a list of all symmetries without duplicates. Returns: list of (node1, node2) """ # Find all symmetric edges - symmetries = [(src, dst) for src, dst, key, edge_type in self._graph.edges(keys=True, data="type") if edge_type == EdgeType.SYMMETRY] + symmetries = [ + (src, dst) + for src, dst, key, edge_type in self._graph.edges(keys=True, data="type") + if edge_type == EdgeType.SYMMETRY + ] # Get rid of duplicates symmetries = list(set([tuple(set(e)) for e in symmetries])) return symmetries @property - def symmetries_full(self): + def symmetries_full(self) -> List[Tuple[Node, Node, Any, Any]]: """Get a list of all symmetries with keys and attributes. - Note: The returned list will contain duplicates (node1, node2) and (node2, node1). + Note: The returned list will contain duplicates (node1, node2) + and (node2, node1). Returns: list of (node1, node2, key, attr) """ # Find all symmetric edges - return [(src, dst, key, attr) for src, dst, key, attr in self._graph.edges(keys=True, data=True) if attr["type"] == EdgeType.SYMMETRY] + return [ + (src, dst, key, attr) + for src, dst, key, attr in self._graph.edges(keys=True, data=True) + if attr["type"] == EdgeType.SYMMETRY + ] - def node_to_index(self, node: Union[str, Node]): + def node_to_index(self, node: NodeRef) -> int: """ - Return the index of the node, accepts either a node or string name of a Node. + Return the index of the node, accepts either `Node` or name. Args: node: The name of the node or the Node object. + Raises: + ValueError if node cannot be found in skeleton. + Returns: The index of the node in the graph. """ @@ -335,12 +395,16 @@ def add_node(self, name: str): """Add a node representing an animal part to the skeleton. Args: - name: The name of the node to add to the skeleton. This name must be unique within the skeleton. + name: The name of the node to add to the skeleton. + This name must be unique within the skeleton. + + Raises: + ValueError: If name is not unique. Returns: None """ - if type(name) is not str: + if not isinstance(name, str): raise TypeError("Cannot add nodes to the skeleton that are not str") if name in self.node_names: @@ -348,7 +412,7 @@ def add_node(self, name: str): self._graph.add_node(Node(name)) - def add_nodes(self, name_list: list): + def add_nodes(self, name_list: List[str]): """ Add a list of nodes representing animal parts to the skeleton. @@ -364,10 +428,14 @@ def add_nodes(self, name_list: list): def delete_node(self, name: str): """Remove a node from the skeleton. - The method removes a node from the skeleton and any edge that is connected to it. + The method removes a node from the skeleton and any edge that is + connected to it. Args: - name: The name of the edge to remove + name: The name of the node to remove + + Raises: + ValueError: If node cannot be found. Returns: None @@ -376,26 +444,31 @@ def delete_node(self, name: str): node = self.find_node(name) self._graph.remove_node(node) except nx.NetworkXError: - raise ValueError("The node named ({}) does not exist, cannot remove it.".format(name)) + raise ValueError( + "The node named ({}) does not exist, cannot remove it.".format(name) + ) - def find_node(self, name: str): + def find_node(self, name: NodeRef) -> Node: """Find node in skeleton by name of node. Args: name: The name of the :class:`Node` (or a :class:`Node`) Returns: - Node, or None if no match found + `Node`, or None if no match found """ if isinstance(name, Node): name = name.name + nodes = [node for node in self.nodes if node.name == name] + if len(nodes) == 1: return nodes[0] - elif len(nodes) > 1: + + if len(nodes) > 1: raise ValueError("Found multiple nodes named ({}).".format(name)) - elif len(nodes) == 0: - return None + + return None def add_edge(self, source: str, destination: str): """Add an edge between two nodes. @@ -403,10 +476,12 @@ def add_edge(self, source: str, destination: str): Args: source: The name of the source node. destination: The name of the destination node. + Raises: + ValueError: If source or destination nodes cannot be found, + or if edge already exists between those nodes. Returns: - None - + None. """ if isinstance(source, Node): source_node = source @@ -421,17 +496,31 @@ def add_edge(self, source: str, destination: str): destination_node = self.find_node(destination) if source_node is None: - raise ValueError("Skeleton does not have source node named ({})".format(source)) + raise ValueError( + "Skeleton does not have source node named ({})".format(source) + ) if destination_node is None: - raise ValueError("Skeleton does not have destination node named ({})".format(destination)) + raise ValueError( + "Skeleton does not have destination node named ({})".format(destination) + ) if self._graph.has_edge(source_node, destination_node): - raise ValueError("Skeleton already has an edge between ({}) and ({}).".format(source, destination)) - - self._graph.add_edge(source_node, destination_node, type = EdgeType.BODY, - edge_insert_idx = self._graph.graph['num_edges_inserted']) - self._graph.graph['num_edges_inserted'] = self._graph.graph['num_edges_inserted'] + 1 + raise ValueError( + "Skeleton already has an edge between ({}) and ({}).".format( + source, destination + ) + ) + + self._graph.add_edge( + source_node, + destination_node, + type=EdgeType.BODY, + edge_insert_idx=self._graph.graph["num_edges_inserted"], + ) + self._graph.graph["num_edges_inserted"] = ( + self._graph.graph["num_edges_inserted"] + 1 + ) def delete_edge(self, source: str, destination: str): """Delete an edge between two nodes. @@ -440,6 +529,10 @@ def delete_edge(self, source: str, destination: str): source: The name of the source node. destination: The name of the destination node. + Raises: + ValueError: If skeleton does not have either source node, + destination node, or edge between them. + Returns: None """ @@ -456,26 +549,38 @@ def delete_edge(self, source: str, destination: str): destination_node = self.find_node(destination) if source_node is None: - raise ValueError("Skeleton does not have source node named ({})".format(source)) + raise ValueError( + "Skeleton does not have source node named ({})".format(source) + ) if destination_node is None: - raise ValueError("Skeleton does not have destination node named ({})".format(destination)) + raise ValueError( + "Skeleton does not have destination node named ({})".format(destination) + ) if not self._graph.has_edge(source_node, destination_node): - raise ValueError("Skeleton has no edge between ({}) and ({}).".format(source, destination)) + raise ValueError( + "Skeleton has no edge between ({}) and ({}).".format( + source, destination + ) + ) self._graph.remove_edge(source_node, destination_node) - def add_symmetry(self, node1:str, node2: str): - """Specify that two parts (nodes) in the skeleton are symmetrical. + def add_symmetry(self, node1: str, node2: str): + """Specify that two parts (nodes) in skeleton are symmetrical. - Certain parts of an animal body can be related as symmetrical parts in a pair. For example, - the left and right hands of a person. + Certain parts of an animal body can be related as symmetrical + parts in a pair. For example, left and right hands of a person. Args: node1: The name of the first part in the symmetric pair node2: The name of the second part in the symmetric pair + Raises: + ValueError: If node1 and node2 match, or if there is already + a symmetry between them. + Returns: None @@ -489,44 +594,70 @@ def add_symmetry(self, node1:str, node2: str): raise ValueError("Cannot add symmetry to the same node.") if self.get_symmetry(node1) is not None: - raise ValueError(f"{node1} is already symmetric with {self.get_symmetry(node1)}.") + raise ValueError( + f"{node1} is already symmetric with {self.get_symmetry(node1)}." + ) if self.get_symmetry(node2) is not None: - raise ValueError(f"{node2} is already symmetric with {self.get_symmetry(node2)}.") + raise ValueError( + f"{node2} is already symmetric with {self.get_symmetry(node2)}." + ) self._graph.add_edge(node1_node, node2_node, type=EdgeType.SYMMETRY) self._graph.add_edge(node2_node, node1_node, type=EdgeType.SYMMETRY) - def delete_symmetry(self, node1:str, node2: str): - """Deletes a previously established symmetry relationship between two nodes. + def delete_symmetry(self, node1: NodeRef, node2: NodeRef): + """ + Deletes a previously established symmetry between two nodes. Args: - node1: The name of the first part in the symmetric pair - node2: The name of the second part in the symmetric pair + node1: One node (by `Node` object or name) in symmetric pair. + node2: Other node in symmetric pair. + + Raises: + ValueError: If there's no symmetry between node1 and node2. Returns: None """ - node1_node, node1_node = self.find_node(node1), self.find_node(node2) + node1_node = self.find_node(node1) + node2_node = self.find_node(node2) - if self.get_symmetry(node1) != node2 or self.get_symmetry(node2) != node1: + if ( + self.get_symmetry(node1) != node2_node + or self.get_symmetry(node2) != node1_node + ): raise ValueError(f"Nodes {node1}, {node2} are not symmetric.") - edges = [(src, dst, key) for src, dst, key, edge_type in self._graph.edges([node1_node, node2_node], keys=True, data="type") if edge_type == EdgeType.SYMMETRY] + edges = [ + (src, dst, key) + for src, dst, key, edge_type in self._graph.edges( + [node1_node, node2_node], keys=True, data="type" + ) + if edge_type == EdgeType.SYMMETRY + ] self._graph.remove_edges_from(edges) - def get_symmetry(self, node:str): - """ Returns the node symmetric with the specified node. + def get_symmetry(self, node: NodeRef) -> Optional[Node]: + """ + Returns the node symmetric with the specified node. Args: - node: The name of the node to query. + node: Node (by `Node` object or name) to query. + + Raises: + ValueError: If node has more than one symmetry. Returns: - The symmetric :class:`Node`, None if no symmetry + The symmetric :class:`Node`, None if no symmetry. """ node_node = self.find_node(node) - symmetry = [dst for src, dst, edge_type in self._graph.edges(node_node, data="type") if edge_type == EdgeType.SYMMETRY] + symmetry = [ + dst + for src, dst, edge_type in self._graph.edges(node_node, data="type") + if edge_type == EdgeType.SYMMETRY + ] if len(symmetry) == 0: return None @@ -535,25 +666,29 @@ def get_symmetry(self, node:str): else: raise ValueError(f"{node} has more than one symmetry.") - def get_symmetry_name(self, node:str): - """ Returns the name of the node symmetric with the specified node. + def get_symmetry_name(self, node: NodeRef) -> Optional[str]: + """ + Returns the name of the node symmetric with the specified node. Args: - node: The name of the node to query. + node: Node (by `Node` object or name) to query. Returns: - name of symmetric node, None if no symmetry + Name of symmetric node, None if no symmetry. """ symmetric_node = self.get_symmetry(node) return None if symmetric_node is None else symmetric_node.name def __getitem__(self, node_name: str) -> dict: """ - Retrieves the node data associated with Skeleton node. + Retrieves the node data associated with skeleton node. Args: node_name: The name from which to retrieve data. + Raises: + ValueError: If node cannot be found. + Returns: A dictionary of data associated with this node. @@ -589,23 +724,27 @@ def relabel_node(self, old_name: str, new_name: str): """ self.relabel_nodes({old_name: new_name}) - def relabel_nodes(self, mapping:dict): + def relabel_nodes(self, mapping: Dict[str, str]): """ Relabel the nodes of the skeleton. Args: - mapping: A dictionary with the old labels as keys and new labels as values. A partial mapping is allowed. + mapping: A dictionary with the old labels as keys and new + labels as values. A partial mapping is allowed. + + Raises: + ValueError: If node already present with one of the new names. Returns: None """ existing_nodes = self.nodes - for k, v in mapping.items(): - if self.has_node(v): + for old_name, new_name in mapping.items(): + if self.has_node(new_name): raise ValueError("Cannot relabel a node to an existing name.") - node = self.find_node(k) + node = self.find_node(old_name) if node is not None: - node.name = v + node.name = new_name # self._graph = nx.relabel_nodes(G=self._graph, mapping=mapping) @@ -627,7 +766,7 @@ def has_nodes(self, names: Iterable[str]) -> bool: Check whether the skeleton has a list of nodes. Args: - name: The list names of the nodes to check for. + names: The list names of the nodes to check for. Returns: True for yes, False for no. @@ -652,33 +791,79 @@ def has_edge(self, source_name: str, dest_name: str) -> bool: True is yes, False if no. """ - source_node, destination_node = self.find_node(source_name), self.find_node(dest_name) + source_node, destination_node = ( + self.find_node(source_name), + self.find_node(dest_name), + ) return self._graph.has_edge(source_node, destination_node) @staticmethod - def to_dict(obj: 'Skeleton', node_to_idx: Dict[Node, int] = None): + def to_dict(obj: "Skeleton", node_to_idx: Optional[Dict[Node, int]] = None) -> Dict: + """ + Convert skeleton to dict; used for saving as JSON. + + Args: + obj: the :object:`Skeleton` to convert + node_to_idx: optional dict which maps :class:`Node`sto index + in some list. This is used when saving + :class:`Labels`where we want to serialize the + :class:`Nodes` outside the :class:`Skeleton` object. + If given, then we replace each :class:`Node` with + specified index before converting :class:`Skeleton`. + Otherwise, we convert :class:`Node` objects with the rest of + the :class:`Skeleton`. + Returns: + dict with data from skeleton + """ # This is a weird hack to serialize the whole _graph into a dict. # I use the underlying to_json and parse it. return json.loads(obj.to_json(node_to_idx)) @classmethod - def from_dict(cls, d: Dict, node_to_idx: Dict[Node, int] = None): + def from_dict(cls, d: Dict, node_to_idx: Dict[Node, int] = None) -> "Skeleton": + """ + Create skeleton from dict; used for loading from JSON. + + Args: + d: the dict from which to deserialize + node_to_idx: optional dict which maps :class:`Node`sto index + in some list. This is used when saving + :class:`Labels`where we want to serialize the + :class:`Nodes` outside the :class:`Skeleton` object. + If given, then we replace each :class:`Node` with + specified index before converting :class:`Skeleton`. + Otherwise, we convert :class:`Node` objects with the rest of + the :class:`Skeleton`. + + Returns: + :class:`Skeleton`. + + """ return Skeleton.from_json(json.dumps(d), node_to_idx) - def to_json(self, node_to_idx: Dict[Node, int] = None) -> str: + def to_json(self, node_to_idx: Optional[Dict[Node, int]] = None) -> str: """ - Convert the skeleton to a JSON representation. + Convert the :class:`Skeleton` to a JSON representation. Args: - node_to_idx (optional): Map for converting `Node` nodes to int + node_to_idx: optional dict which maps :class:`Node`sto index + in some list. This is used when saving + :class:`Labels`where we want to serialize the + :class:`Nodes` outside the :class:`Skeleton` object. + If given, then we replace each :class:`Node` with + specified index before converting :class:`Skeleton`. + Otherwise, we convert :class:`Node` objects with the rest of + the :class:`Skeleton`. Returns: - A string containing the JSON representation of the Skeleton. + A string containing the JSON representation of the skeleton. """ - jsonpickle.set_encoder_options('simplejson', sort_keys=True, indent=4) + jsonpickle.set_encoder_options("simplejson", sort_keys=True, indent=4) if node_to_idx is not None: - indexed_node_graph = nx.relabel_nodes(G=self._graph, mapping=node_to_idx) # map nodes to int + indexed_node_graph = nx.relabel_nodes( + G=self._graph, mapping=node_to_idx + ) # map nodes to int else: indexed_node_graph = self._graph @@ -687,35 +872,50 @@ def to_json(self, node_to_idx: Dict[Node, int] = None) -> str: return json_str - def save_json(self, filename: str, node_to_idx: Dict[Node, int] = None): - """Save the skeleton as JSON file. + def save_json(self, filename: str, node_to_idx: Optional[Dict[Node, int]] = None): + """ + Save the :class:`Skeleton` as JSON file. - Output the complete skeleton to a file in JSON format. + Output the complete skeleton to a file in JSON format. - Args: - filename: The filename to save the JSON to. - node_to_idx (optional): Map for converting `Node` nodes to int + Args: + filename: The filename to save the JSON to. + node_to_idx: optional dict which maps :class:`Node`sto index + in some list. This is used when saving + :class:`Labels`where we want to serialize the + :class:`Nodes` outside the :class:`Skeleton` object. + If given, then we replace each :class:`Node` with + specified index before converting :class:`Skeleton`. + Otherwise, we convert :class:`Node` objects with the rest of + the :class:`Skeleton`. - Returns: - None - """ + Returns: + None + """ json_str = self.to_json(node_to_idx) - with open(filename, 'w') as file: + with open(filename, "w") as file: file.write(json_str) @classmethod - def from_json(cls, json_str: str, idx_to_node: Dict[int, Node] = None): + def from_json( + cls, json_str: str, idx_to_node: Dict[int, Node] = None + ) -> "Skeleton": """ - Parse a JSON string containing the Skeleton object and create an instance from it. + Instantiate :class:`Skeleton` from JSON string. Args: json_str: The JSON encoded Skeleton. - idx_to_node (optional): Map for converting int node in json back to corresponding `Node`. + idx_to_node: optional dict which maps an int (indexing a + list of :class:`Node` objects) to the already + deserialized :class:`Node`. + This should invert `node_to_idx` we used when saving. + If not given, then we'll assume each :class:`Node` was + left in the :class:`Skeleton` when it was saved. Returns: - An instance of the Skeleton object decoded from the JSON. + An instance of the `Skeleton` object decoded from the JSON. """ graph = json_graph.node_link_graph(jsonpickle.decode(json_str)) @@ -729,27 +929,36 @@ def from_json(cls, json_str: str, idx_to_node: Dict[int, Node] = None): return skeleton @classmethod - def load_json(cls, filename: str, idx_to_node: Dict[int, Node] = None): - """Load a skeleton from a JSON file. + def load_json( + cls, filename: str, idx_to_node: Dict[int, Node] = None + ) -> "Skeleton": + """ + Load a skeleton from a JSON file. - This method will load the Skeleton from JSON file saved with; :meth:`~Skeleton.save_json` + This method will load the Skeleton from JSON file saved + with; :meth:`~Skeleton.save_json` Args: - filename: The file that contains the JSON specifying the skeleton. - idx_to_node (optional): Map for converting int node in json back to corresponding `Node`. + filename: The file that contains the JSON. + idx_to_node: optional dict which maps an int (indexing a + list of :class:`Node` objects) to the already + deserialized :class:`Node`. + This should invert `node_to_idx` we used when saving. + If not given, then we'll assume each :class:`Node` was + left in the :class:`Skeleton` when it was saved. Returns: - The Skeleton object stored in the JSON filename. + The `Skeleton` object stored in the JSON filename. """ - with open(filename, 'r') as file: - skeleton = Skeleton.from_json(file.read(), idx_to_node) + with open(filename, "r") as file: + skeleton = cls.from_json(file.read(), idx_to_node) return skeleton @classmethod - def load_hdf5(cls, file: Union[str, h5.File], name: str): + def load_hdf5(cls, file: H5FileRef, name: str) -> List["Skeleton"]: """ Load a specific skeleton (by name) from the HDF5 file. @@ -758,68 +967,69 @@ def load_hdf5(cls, file: Union[str, h5.File], name: str): name: The name of the skeleton. Returns: - The skeleton instance stored in the HDF5 file. + The specified `Skeleton` instance stored in the HDF5 file. """ - if type(file) is str: + if isinstance(file, str): with h5.File(file) as _file: - skeletons = Skeleton._load_hdf5(_file) # Load all skeletons + skeletons = cls._load_hdf5(_file) # Load all skeletons else: - skeletons = Skeleton._load_hdf5(file) + skeletons = cls._load_hdf5(file) return skeletons[name] @classmethod - def load_all_hdf5(cls, file: Union[str, h5.File], - return_dict: bool = False) -> Union[List['Skeleton'], Dict[str, 'Skeleton']]: + def load_all_hdf5( + cls, file: H5FileRef, return_dict: bool = False + ) -> Union[List["Skeleton"], Dict[str, "Skeleton"]]: """ Load all skeletons found in the HDF5 file. Args: file: The file name or open h5.File - return_dict: True if the the return value should be a dict where the - keys are skeleton names and values the corresponding skeleton. False - if the return should just be a list of the skeletons. + return_dict: Whether the the return value should be a dict + where the keys are skeleton names and values the + corresponding skeleton. If False, then method will + return just a list of the skeletons. Returns: - The skeleton instances stored in the HDF5 file. Either in List or Dict form. + The skeleton instances stored in the HDF5 file. + Either in List or Dict form. """ - if type(file) is str: + if isinstance(file, str): with h5.File(file) as _file: - skeletons = Skeleton._load_hdf5(_file) # Load all skeletons + skeletons = cls._load_hdf5(_file) # Load all skeletons else: - skeletons = Skeleton._load_hdf5(file) + skeletons = cls._load_hdf5(file) if return_dict: return skeletons - else: - return list(skeletons.values()) + + return list(skeletons.values()) @classmethod def _load_hdf5(cls, file: h5.File): skeletons = {} - for name, json_str in file['skeleton'].attrs.items(): - skeletons[name] = Skeleton.from_json(json_str) + for name, json_str in file["skeleton"].attrs.items(): + skeletons[name] = cls.from_json(json_str) return skeletons - def save_hdf5(self, file: Union[str, h5.File]): - if type(file) is str: - with h5.File(file) as _file: - self._save_hdf5(_file) - else: - self._save_hdf5(file) - @classmethod - def save_all_hdf5(self, file: Union[str, h5.File], skeletons: List['Skeleton']): + def save_all_hdf5(self, file: H5FileRef, skeletons: List["Skeleton"]): """ - Convenience method to save a list of skeletons to HDF5 file. Skeletons are saved - as attributes of a /skeleton group in the file. + Convenience method to save a list of skeletons to HDF5 file. + + Skeletons are saved as attributes of a /skeleton group in the + file. Args: file: The filename or the open h5.File object. skeletons: The list of skeletons to save. + Raises: + ValueError: If multiple skeletons have the same name. + Returns: None """ @@ -833,32 +1043,51 @@ def save_all_hdf5(self, file: Union[str, h5.File], skeletons: List['Skeleton']): for skeleton in skeletons: skeleton.save_hdf5(file) + def save_hdf5(self, file: H5FileRef): + """ + Wrapper for HDF5 saving which takes either filename or h5.File. + + Args: + file: can be filename (string) or `h5.File` object + + Returns: + None + """ + + if isinstance(file, str): + with h5.File(file) as _file: + self._save_hdf5(_file) + else: + self._save_hdf5(file) + def _save_hdf5(self, file: h5.File): """ Actual implementation of HDF5 saving. Args: - file: The open h5.File to write the skeleton data too. + file: The open h5.File to write the skeleton data to. Returns: None """ # All skeleton will be put as sub-groups in the skeleton group - if 'skeleton' not in file: - all_sk_group = file.create_group('skeleton', track_order=True) + if "skeleton" not in file: + all_sk_group = file.create_group("skeleton", track_order=True) else: - all_sk_group = file.require_group('skeleton') + all_sk_group = file.require_group("skeleton") # Write the dataset to JSON string, then store it in a string # attribute all_sk_group.attrs[self.name] = np.string_(self.to_json()) @classmethod - def load_mat(cls, filename: str): + def load_mat(cls, filename: str) -> "Skeleton": """ - Load the skeleton from a Matlab MAT file. This is to support backwards - compatibility with old LEAP MATLAB code and datasets. + Load the skeleton from a Matlab MAT file. + + This is to support backwards compatibility with old LEAP + MATLAB code and datasets. Args: filename: The name of the skeleton file @@ -873,33 +1102,24 @@ def load_mat(cls, filename: str): skel_mat = loadmat(filename) skel_mat["nodes"] = skel_mat["nodes"][0][0] # convert to scalar - skel_mat["edges"] = skel_mat["edges"] - 1 # convert to 0-based indexing + skel_mat["edges"] = skel_mat["edges"] - 1 # convert to 0-based indexing - node_names = skel_mat['nodeNames'] + node_names = skel_mat["nodeNames"] node_names = [str(n[0][0]) for n in node_names] skeleton.add_nodes(node_names) for k in range(len(skel_mat["edges"])): edge = skel_mat["edges"][k] - skeleton.add_edge(source=node_names[edge[0]], destination=node_names[edge[1]]) + skeleton.add_edge( + source=node_names[edge[0]], destination=node_names[edge[1]] + ) return skeleton def __str__(self): return "%s(name=%r)" % (self.__class__.__name__, self.name) - def __eq__(self, other: 'Skeleton'): - - # First check names, duh! - if other.name != self.name: - return False - - # Then check if the graphs match - return self.matches(other) - def __hash__(self): """ - Construct a hash from skeleton name, which we force to be immutable so hashes - will not change. + Construct a hash from skeleton id. """ - return hash(self.name) - + return id(self) diff --git a/sleap/training_profiles/default_centroids.json b/sleap/training_profiles/default_centroids.json index 2e0f59b6d..9843f75b0 100644 --- a/sleap/training_profiles/default_centroids.json +++ b/sleap/training_profiles/default_centroids.json @@ -1 +1,48 @@ -{"model": {"output_type": 2, "backbone": {"down_blocks": 3, "up_blocks": 3, "convs_per_depth": 2, "num_filters": 16, "kernel_size": 5, "upsampling_layers": true, "interp": "bilinear"}, "skeletons": null, "backbone_name": "UNet"}, "trainer": {"val_size": 0.1, "optimizer": "adam", "learning_rate": 0.0001, "amsgrad": true, "batch_size": 4, "num_epochs": 100, "steps_per_epoch": 200, "shuffle_initially": true, "shuffle_every_epoch": true, "augment_rotation": 180, "augment_scale_min": 1.0, "augment_scale_max": 1.0, "save_every_epoch": false, "save_best_val": true, "reduce_lr_min_delta": 1e-06, "reduce_lr_factor": 0.5, "reduce_lr_patience": 5, "reduce_lr_cooldown": 3, "reduce_lr_min_lr": 1e-10, "early_stopping_min_delta": 1e-08, "early_stopping_patience": 15, "scale": 0.25, "sigma": 5.0, "instance_crop": false}, "labels_filename": null, "run_name": null, "save_dir": null, "best_model_filename": null, "newest_model_filename": null, "final_model_filename": null} \ No newline at end of file +{ + "model": { + "output_type": 2, + "backbone": { + "down_blocks": 3, + "up_blocks": 3, + "convs_per_depth": 2, + "num_filters": 16, + "kernel_size": 5, + "upsampling_layers": true, + "interp": "bilinear" + }, + "skeletons": null, + "backbone_name": "UNet" + }, + "trainer": { + "val_size": 0.1, + "optimizer": "adam", + "learning_rate": 0.0001, + "amsgrad": true, + "batch_size": 4, + "num_epochs": 100, + "steps_per_epoch": 200, + "shuffle_initially": true, + "shuffle_every_epoch": true, + "augment_rotation": 180, + "augment_scale_min": 1.0, + "augment_scale_max": 1.0, + "save_every_epoch": false, + "save_best_val": true, + "reduce_lr_min_delta": "1e-06", + "reduce_lr_factor": 0.5, + "reduce_lr_patience": 5, + "reduce_lr_cooldown": 3, + "reduce_lr_min_lr": "1e-10", + "early_stopping_min_delta": "1e-08", + "early_stopping_patience": 15, + "scale": 0.25, + "sigma": 5.0, + "instance_crop": false + }, + "labels_filename": null, + "run_name": null, + "save_dir": null, + "best_model_filename": null, + "newest_model_filename": null, + "final_model_filename": null +} \ No newline at end of file diff --git a/sleap/training_profiles/default_confmaps.json b/sleap/training_profiles/default_confmaps.json index 4503d7e8b..6d3393f0f 100644 --- a/sleap/training_profiles/default_confmaps.json +++ b/sleap/training_profiles/default_confmaps.json @@ -1 +1,50 @@ -{"model": {"output_type": 0, "backbone": {"down_blocks": 3, "up_blocks": 3, "convs_per_depth": 2, "num_filters": 32, "kernel_size": 5, "upsampling_layers": true, "interp": "bilinear"}, "skeletons": null, "backbone_name": "UNet"}, "trainer": {"val_size": 0.1, "optimizer": "adam", "learning_rate": 0.0001, "amsgrad": true, "batch_size": 2, "num_epochs": 150, "steps_per_epoch": 200, "shuffle_initially": true, "shuffle_every_epoch": true, "augment_rotation": 180, "augment_scale_min": 1.0, "augment_scale_max": 1.0, "save_every_epoch": false, "save_best_val": true, "reduce_lr_min_delta": 1e-06, "reduce_lr_factor": 0.5, "reduce_lr_patience": 5, "reduce_lr_cooldown": 3, "reduce_lr_min_lr": 1e-10, "early_stopping_min_delta": 1e-08, "early_stopping_patience": 15, "scale": 1, "sigma": 5.0, "instance_crop": true}, "labels_filename": null, "run_name": null, "save_dir": null, "best_model_filename": null, "newest_model_filename": null, "final_model_filename": null} \ No newline at end of file +{ + "model": { + "output_type": 0, + "backbone": { + "down_blocks": 3, + "up_blocks": 3, + "convs_per_depth": 2, + "num_filters": 32, + "kernel_size": 5, + "upsampling_layers": true, + "interp": "bilinear" + }, + "skeletons": null, + "backbone_name": "UNet" + }, + "trainer": { + "val_size": 0.1, + "optimizer": "adam", + "learning_rate": 0.0001, + "amsgrad": true, + "batch_size": 2, + "num_epochs": 150, + "steps_per_epoch": 200, + "shuffle_initially": true, + "shuffle_every_epoch": true, + "augment_rotation": 180, + "augment_scale_min": 1.0, + "augment_scale_max": 1.0, + "save_every_epoch": false, + "save_best_val": true, + "reduce_lr_min_delta": "1e-06", + "reduce_lr_factor": 0.5, + "reduce_lr_patience": 5, + "reduce_lr_cooldown": 3, + "reduce_lr_min_lr": "1e-10", + "early_stopping_min_delta": "1e-08", + "early_stopping_patience": 15, + "scale": 1, + "sigma": 5.0, + "instance_crop": true, + "min_crop_size": 32, + "negative_samples": 10 + }, + "labels_filename": null, + "run_name": null, + "save_dir": null, + "best_model_filename": null, + "newest_model_filename": null, + "final_model_filename": null +} \ No newline at end of file diff --git a/sleap/training_profiles/default_pafs.json b/sleap/training_profiles/default_pafs.json index 5c04a2acc..8cdd66dde 100644 --- a/sleap/training_profiles/default_pafs.json +++ b/sleap/training_profiles/default_pafs.json @@ -1 +1,48 @@ -{"model": {"output_type": 1, "backbone": {"down_blocks": 3, "up_blocks": 3, "upsampling_layers": true, "num_filters": 32, "interp": "bilinear"}, "skeletons": null, "backbone_name": "LeapCNN"}, "trainer": {"val_size": 0.15, "optimizer": "adam", "learning_rate": 5e-5, "amsgrad": true, "batch_size": 2, "num_epochs": 150, "steps_per_epoch": 100, "shuffle_initially": true, "shuffle_every_epoch": true, "augment_rotation": 180, "augment_scale_min": 1.0, "augment_scale_max": 1.0, "save_every_epoch": false, "save_best_val": true, "reduce_lr_min_delta": 1e-6, "reduce_lr_factor": 0.5, "reduce_lr_patience": 8, "reduce_lr_cooldown": 3, "reduce_lr_min_lr": 1e-10, "early_stopping_min_delta": 1e-08, "early_stopping_patience": 15, "scale": 1, "sigma": 5.0, "instance_crop": true}, "labels_filename": null, "run_name": null, "save_dir": null, "best_model_filename": null, "newest_model_filename": null, "final_model_filename": null} \ No newline at end of file +{ + "model": { + "output_type": 1, + "backbone": { + "down_blocks": 3, + "up_blocks": 3, + "upsampling_layers": true, + "num_filters": 32, + "interp": "bilinear" + }, + "skeletons": null, + "backbone_name": "LeapCNN" + }, + "trainer": { + "val_size": 0.15, + "optimizer": "adam", + "learning_rate": "5e-5", + "amsgrad": true, + "batch_size": 2, + "num_epochs": 150, + "steps_per_epoch": 100, + "shuffle_initially": true, + "shuffle_every_epoch": true, + "augment_rotation": 180, + "augment_scale_min": 1.0, + "augment_scale_max": 1.0, + "save_every_epoch": false, + "save_best_val": true, + "reduce_lr_min_delta": "1e-6", + "reduce_lr_factor": 0.5, + "reduce_lr_patience": 8, + "reduce_lr_cooldown": 3, + "reduce_lr_min_lr": "1e-10", + "early_stopping_min_delta": "1e-08", + "early_stopping_patience": 15, + "scale": 1, + "sigma": 5.0, + "instance_crop": true, + "min_crop_size": 32, + "negative_samples": 10 + }, + "labels_filename": null, + "run_name": null, + "save_dir": null, + "best_model_filename": null, + "newest_model_filename": null, + "final_model_filename": null +} \ No newline at end of file diff --git a/sleap/util.py b/sleap/util.py index 289c33ff2..d91deadfc 100644 --- a/sleap/util.py +++ b/sleap/util.py @@ -2,40 +2,94 @@ A miscellaneous set of utility functions. Try not to put things in here unless they really have no other place. """ + import os +import re import h5py as h5 import numpy as np import attr import psutil +import json +import rapidjson + +from typing import Any, Dict, Hashable, Iterable, List, Optional + + +def json_loads(json_str: str) -> Dict: + """ + A simple wrapper around the JSON decoder we are using. + + Args: + json_str: JSON string to decode. + + Returns: + Result of decoding JSON string. + """ + try: + return rapidjson.loads(json_str) + except: + return json.loads(json_str) + + +def json_dumps(d: Dict, filename: str = None): + """ + A simple wrapper around the JSON encoder we are using. + + Args: + d: The dict to write. + filename: The filename to write to. + + Returns: + None + """ + + encoder = rapidjson -from typing import Callable + if filename: + with open(filename, "w") as f: + encoder.dump(d, f, ensure_ascii=False) + else: + return encoder.dumps(d) -def attr_to_dtype(cls): +def attr_to_dtype(cls: Any): + """ + Converts classes with basic types to numpy composite dtypes. + + Arguments: + cls: class to convert + + Returns: + numpy dtype. + """ dtype_list = [] for field in attr.fields(cls): if field.type == str: dtype_list.append((field.name, h5.special_dtype(vlen=str))) elif field.type is None: - raise TypeError(f"numpy dtype for {cls} cannot be constructed because no " + - "type information found. Make sure each field is type annotated.") + raise TypeError( + f"numpy dtype for {cls} cannot be constructed because no " + + "type information found. Make sure each field is type annotated." + ) elif field.type in [str, int, float, bool]: dtype_list.append((field.name, field.type)) else: - raise TypeError(f"numpy dtype for {cls} cannot be constructed because no " + - f"{field.type} is not supported.") + raise TypeError( + f"numpy dtype for {cls} cannot be constructed because no " + + f"{field.type} is not supported." + ) return np.dtype(dtype_list) def usable_cpu_count() -> int: - """Get number of CPUs usable by the current process. + """ + Gets number of CPUs usable by the current process. Takes into consideration cpusets restrictions. - Returns - ------- + Returns: The number of usable cpus """ try: @@ -47,17 +101,25 @@ def usable_cpu_count() -> int: result = os.cpu_count() return result + def save_dict_to_hdf5(h5file: h5.File, path: str, dic: dict): """ - Saves dictionary to an HDF5 file, calls itself recursively if items in - dictionary are not np.ndarray, np.int64, np.float64, str, bytes. Objects - must be iterable. + Saves dictionary to an HDF5 file. + + Calls itself recursively if items in dictionary are not + `np.ndarray`, `np.int64`, `np.float64`, `str`, or bytes. + Objects must be iterable. Args: - h5file: The HDF5 filename object to save the data to. Assume it is open. + h5file: The HDF5 filename object to save the data to. + Assume it is open. path: The path to group save the dict under. dic: The dict to save. + Raises: + ValueError: If type for item in dict cannot be saved. + + Returns: None """ @@ -71,37 +133,47 @@ def save_dict_to_hdf5(h5file: h5.File, path: str, dic: dict): items_encoded = [] for it in item: if isinstance(it, str): - items_encoded.append(it.encode('utf8')) + items_encoded.append(it.encode("utf8")) else: items_encoded.append(it) h5file[path + key] = np.asarray(items_encoded) elif isinstance(item, (str)): - h5file[path + key] = item.encode('utf8') + h5file[path + key] = item.encode("utf8") elif isinstance(item, (np.ndarray, np.int64, np.float64, str, bytes, float)): h5file[path + key] = item elif isinstance(item, dict): - save_dict_to_hdf5(h5file, path + key + '/', item) + save_dict_to_hdf5(h5file, path + key + "/", item) elif isinstance(item, int): h5file[path + key] = item else: - raise ValueError('Cannot save %s type'%type(item)) + raise ValueError("Cannot save %s type" % type(item)) + + +def frame_list(frame_str: str) -> Optional[List[int]]: + """ + Converts 'n-m' string to list of ints. -def frame_list(frame_str: str): + Args: + frame_str: string representing range + + Returns: + List of ints, or None if string does not represent valid range. + """ # Handle ranges of frames. Must be of the form "1-200" - if '-' in frame_str: - min_max = frame_str.split('-') + if "-" in frame_str: + min_max = frame_str.split("-") min_frame = int(min_max[0]) max_frame = int(min_max[1]) - return list(range(min_frame, max_frame+1)) + return list(range(min_frame, max_frame + 1)) return [int(x) for x in frame_str.split(",")] if len(frame_str) else None -def uniquify(seq): +def uniquify(seq: Iterable[Hashable]) -> List: """ - Given a list, return unique elements but preserve order. + Returns unique elements from list, preserving order. Note: This will not work on Python 3.5 or lower since dicts don't preserve order. @@ -110,9 +182,52 @@ def uniquify(seq): seq: The list to remove duplicates from. Returns: - The unique elements from the input list extracted in original order. + The unique elements from the input list extracted in original + order. """ # Raymond Hettinger # https://twitter.com/raymondh/status/944125570534621185 - return list(dict.fromkeys(seq)) \ No newline at end of file + return list(dict.fromkeys(seq)) + + +def weak_filename_match(filename_a: str, filename_b: str) -> bool: + """ + Check if paths probably point to same file. + + Compares the filename and names of two directories up. + + Args: + filename_a: first path to check + filename_b: path to check against first path + + Returns: + True if the paths probably match. + """ + # convert all path separators to / + filename_a = filename_a.replace("\\", "/") + filename_b = filename_b.replace("\\", "/") + + # remove unique pid so we can match tmp directories for same zip + filename_a = re.sub("/tmp_\d+_", "tmp_", filename_a) + filename_b = re.sub("/tmp_\d+_", "tmp_", filename_b) + + # check if last three parts of path match + return filename_a.split("/")[-3:] == filename_b.split("/")[-3:] + + +def dict_cut(d: Dict, a: int, b: int) -> Dict: + """ + Helper function for creating subdictionary by numeric indexing of items. + + Assumes that `dict.items()` will have a fixed order. + + Args: + d: The dictionary to "split" + a: Start index of range of items to include in result. + b: End index of range of items to include in result. + + Returns: + A dictionary that contains a subset of the items in the original dict. + """ + return dict(list(d.items())[a:b]) diff --git a/tests/conftest.py b/tests/conftest.py index 52b682e44..8c850b0ff 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,7 +3,7 @@ try: import pytestqt except: - logging.warning('Could not import PySide2 or pytestqt, skipping GUI tests.') + logging.warning("Could not import PySide2 or pytestqt, skipping GUI tests.") collect_ignore_glob = ["gui/*"] from tests.fixtures.skeletons import * diff --git a/tests/data/training_profiles/set_a/default_centroids.json b/tests/data/training_profiles/set_a/default_centroids.json new file mode 100644 index 000000000..2e0f59b6d --- /dev/null +++ b/tests/data/training_profiles/set_a/default_centroids.json @@ -0,0 +1 @@ +{"model": {"output_type": 2, "backbone": {"down_blocks": 3, "up_blocks": 3, "convs_per_depth": 2, "num_filters": 16, "kernel_size": 5, "upsampling_layers": true, "interp": "bilinear"}, "skeletons": null, "backbone_name": "UNet"}, "trainer": {"val_size": 0.1, "optimizer": "adam", "learning_rate": 0.0001, "amsgrad": true, "batch_size": 4, "num_epochs": 100, "steps_per_epoch": 200, "shuffle_initially": true, "shuffle_every_epoch": true, "augment_rotation": 180, "augment_scale_min": 1.0, "augment_scale_max": 1.0, "save_every_epoch": false, "save_best_val": true, "reduce_lr_min_delta": 1e-06, "reduce_lr_factor": 0.5, "reduce_lr_patience": 5, "reduce_lr_cooldown": 3, "reduce_lr_min_lr": 1e-10, "early_stopping_min_delta": 1e-08, "early_stopping_patience": 15, "scale": 0.25, "sigma": 5.0, "instance_crop": false}, "labels_filename": null, "run_name": null, "save_dir": null, "best_model_filename": null, "newest_model_filename": null, "final_model_filename": null} \ No newline at end of file diff --git a/tests/data/training_profiles/set_a/default_confmaps.json b/tests/data/training_profiles/set_a/default_confmaps.json new file mode 100644 index 000000000..4503d7e8b --- /dev/null +++ b/tests/data/training_profiles/set_a/default_confmaps.json @@ -0,0 +1 @@ +{"model": {"output_type": 0, "backbone": {"down_blocks": 3, "up_blocks": 3, "convs_per_depth": 2, "num_filters": 32, "kernel_size": 5, "upsampling_layers": true, "interp": "bilinear"}, "skeletons": null, "backbone_name": "UNet"}, "trainer": {"val_size": 0.1, "optimizer": "adam", "learning_rate": 0.0001, "amsgrad": true, "batch_size": 2, "num_epochs": 150, "steps_per_epoch": 200, "shuffle_initially": true, "shuffle_every_epoch": true, "augment_rotation": 180, "augment_scale_min": 1.0, "augment_scale_max": 1.0, "save_every_epoch": false, "save_best_val": true, "reduce_lr_min_delta": 1e-06, "reduce_lr_factor": 0.5, "reduce_lr_patience": 5, "reduce_lr_cooldown": 3, "reduce_lr_min_lr": 1e-10, "early_stopping_min_delta": 1e-08, "early_stopping_patience": 15, "scale": 1, "sigma": 5.0, "instance_crop": true}, "labels_filename": null, "run_name": null, "save_dir": null, "best_model_filename": null, "newest_model_filename": null, "final_model_filename": null} \ No newline at end of file diff --git a/tests/data/training_profiles/set_a/default_pafs.json b/tests/data/training_profiles/set_a/default_pafs.json new file mode 100644 index 000000000..5c04a2acc --- /dev/null +++ b/tests/data/training_profiles/set_a/default_pafs.json @@ -0,0 +1 @@ +{"model": {"output_type": 1, "backbone": {"down_blocks": 3, "up_blocks": 3, "upsampling_layers": true, "num_filters": 32, "interp": "bilinear"}, "skeletons": null, "backbone_name": "LeapCNN"}, "trainer": {"val_size": 0.15, "optimizer": "adam", "learning_rate": 5e-5, "amsgrad": true, "batch_size": 2, "num_epochs": 150, "steps_per_epoch": 100, "shuffle_initially": true, "shuffle_every_epoch": true, "augment_rotation": 180, "augment_scale_min": 1.0, "augment_scale_max": 1.0, "save_every_epoch": false, "save_best_val": true, "reduce_lr_min_delta": 1e-6, "reduce_lr_factor": 0.5, "reduce_lr_patience": 8, "reduce_lr_cooldown": 3, "reduce_lr_min_lr": 1e-10, "early_stopping_min_delta": 1e-08, "early_stopping_patience": 15, "scale": 1, "sigma": 5.0, "instance_crop": true}, "labels_filename": null, "run_name": null, "save_dir": null, "best_model_filename": null, "newest_model_filename": null, "final_model_filename": null} \ No newline at end of file diff --git a/tests/data/training_profiles/set_b/test_confmaps.json b/tests/data/training_profiles/set_b/test_confmaps.json new file mode 100644 index 000000000..2245a173c --- /dev/null +++ b/tests/data/training_profiles/set_b/test_confmaps.json @@ -0,0 +1 @@ +{"model": {"output_type": 0, "backbone": {"down_blocks": 3, "up_blocks": 3, "convs_per_depth": 2, "num_filters": 32, "kernel_size": 5, "upsampling_layers": true, "interp": "bilinear"}, "skeletons": null, "backbone_name": "UNet"}, "trainer": {"num_epochs": 17}} \ No newline at end of file diff --git a/tests/fixtures/datasets.py b/tests/fixtures/datasets.py index ec370a899..9bb5db7ea 100644 --- a/tests/fixtures/datasets.py +++ b/tests/fixtures/datasets.py @@ -1,14 +1,24 @@ import os import pytest -from sleap.instance import Instance, Point, LabeledFrame, Track +from sleap.instance import ( + Instance, + PredictedInstance, + Point, + PredictedPoint, + LabeledFrame, + Track, +) +from sleap.skeleton import Skeleton from sleap.io.dataset import Labels +from sleap.io.video import Video TEST_JSON_LABELS = "tests/data/json_format_v1/centered_pair.json" TEST_JSON_PREDICTIONS = "tests/data/json_format_v2/centered_pair_predictions.json" TEST_JSON_MIN_LABELS = "tests/data/json_format_v2/minimal_instance.json" TEST_MAT_LABELS = "tests/data/mat/labels.mat" + @pytest.fixture def centered_pair_labels(): return Labels.load_json(TEST_JSON_LABELS) @@ -18,14 +28,84 @@ def centered_pair_labels(): def centered_pair_predictions(): return Labels.load_json(TEST_JSON_PREDICTIONS) + @pytest.fixture def min_labels(): return Labels.load_json(TEST_JSON_MIN_LABELS) + @pytest.fixture def mat_labels(): return Labels.load_mat(TEST_MAT_LABELS) + +@pytest.fixture +def simple_predictions(): + + video = Video.from_filename("video.mp4") + + skeleton = Skeleton() + skeleton.add_node("a") + skeleton.add_node("b") + + track_a = Track(0, "a") + track_b = Track(0, "b") + + labels = Labels() + + instances = [] + instances.append( + PredictedInstance( + skeleton=skeleton, + score=2, + track=track_a, + points=dict( + a=PredictedPoint(1, 1, score=0.5), b=PredictedPoint(1, 1, score=0.5) + ), + ) + ) + instances.append( + PredictedInstance( + skeleton=skeleton, + score=5, + track=track_b, + points=dict( + a=PredictedPoint(1, 1, score=0.7), b=PredictedPoint(1, 1, score=0.7) + ), + ) + ) + + labeled_frame = LabeledFrame(video, frame_idx=0, instances=instances) + labels.append(labeled_frame) + + instances = [] + instances.append( + PredictedInstance( + skeleton=skeleton, + score=3, + track=track_a, + points=dict( + a=PredictedPoint(4, 5, score=1.5), b=PredictedPoint(1, 1, score=1.0) + ), + ) + ) + instances.append( + PredictedInstance( + skeleton=skeleton, + score=6, + track=track_b, + points=dict( + a=PredictedPoint(6, 13, score=1.7), b=PredictedPoint(1, 1, score=1.0) + ), + ) + ) + + labeled_frame = LabeledFrame(video, frame_idx=1, instances=instances) + labels.append(labeled_frame) + + return labels + + @pytest.fixture def multi_skel_vid_labels(hdf5_vid, small_robot_mp4_vid, skeleton, stickman): """ @@ -60,7 +140,9 @@ def multi_skel_vid_labels(hdf5_vid, small_robot_mp4_vid, skeleton, stickman): stickman_instances = [] for i in range(6): - stickman_instances.append(Instance(skeleton=stickman, track=stick_tracks[i])) + stickman_instances.append( + Instance(skeleton=stickman, track=stick_tracks[i]) + ) for node in stickman.nodes: stickman_instances[i][node] = Point(x=i % vid.width, y=i % vid.height) @@ -70,4 +152,3 @@ def multi_skel_vid_labels(hdf5_vid, small_robot_mp4_vid, skeleton, stickman): labels = Labels(labels) return labels - diff --git a/tests/fixtures/instances.py b/tests/fixtures/instances.py index dcac5ce29..862577457 100644 --- a/tests/fixtures/instances.py +++ b/tests/fixtures/instances.py @@ -12,21 +12,23 @@ def instances(skeleton): instances = [] for i in range(NUM_INSTANCES): instance = Instance(skeleton=skeleton) - instance['head'] = Point(i*1, i*2) - instance['left-wing'] = Point(10 + i * 1, 10 + i * 2) - instance['right-wing'] = Point(20 + i * 1, 20 + i * 2) + instance["head"] = Point(i * 1, i * 2) + instance["left-wing"] = Point(10 + i * 1, 10 + i * 2) + instance["right-wing"] = Point(20 + i * 1, 20 + i * 2) # Lets make an NaN entry to test skip_nan as well - instance['thorax'] + instance["thorax"] instances.append(instance) return instances + @pytest.fixture def predicted_instances(instances): return [PredictedInstance.from_instance(i, 1.0) for i in instances] + @pytest.fixture def multi_skel_instances(skeleton, stickman): """ @@ -39,21 +41,21 @@ def multi_skel_instances(skeleton, stickman): instances = [] for i in range(NUM_INSTANCES): instance = Instance(skeleton=skeleton, video=None, frame_idx=i) - instance['head'] = Point(i*1, i*2) - instance['left-wing'] = Point(10 + i * 1, 10 + i * 2) - instance['right-wing'] = Point(20 + i * 1, 20 + i * 2) + instance["head"] = Point(i * 1, i * 2) + instance["left-wing"] = Point(10 + i * 1, 10 + i * 2) + instance["right-wing"] = Point(20 + i * 1, 20 + i * 2) # Lets make an NaN entry to test skip_nan as well - instance['thorax'] + instance["thorax"] instances.append(instance) # Setup some instances of the stick man on the same frames for i in range(NUM_INSTANCES): instance = Instance(skeleton=stickman, video=None, frame_idx=i) - instance['head'] = Point(i * 10, i * 20) - instance['body'] = Point(100 + i * 1, 100 + i * 2) - instance['left-arm'] = Point(200 + i * 1, 200 + i * 2) + instance["head"] = Point(i * 10, i * 20) + instance["body"] = Point(100 + i * 1, 100 + i * 2) + instance["left-arm"] = Point(200 + i * 1, 200 + i * 2) instances.append(instance) diff --git a/tests/fixtures/skeletons.py b/tests/fixtures/skeletons.py index 13a2b741f..c340270bb 100644 --- a/tests/fixtures/skeletons.py +++ b/tests/fixtures/skeletons.py @@ -2,23 +2,27 @@ from sleap.skeleton import Skeleton + @pytest.fixture def stickman(): # Make a skeleton with a space in its name to test things. stickman = Skeleton("Stick man") - stickman.add_nodes(['head', 'neck', 'body', 'right-arm', 'left-arm', 'right-leg', 'left-leg']) - stickman.add_edge('neck', 'head') - stickman.add_edge('body', 'neck') - stickman.add_edge('body', 'right-arm') - stickman.add_edge('body', 'left-arm') - stickman.add_edge('body', 'right-leg') - stickman.add_edge('body', 'left-leg') + stickman.add_nodes( + ["head", "neck", "body", "right-arm", "left-arm", "right-leg", "left-leg"] + ) + stickman.add_edge("neck", "head") + stickman.add_edge("body", "neck") + stickman.add_edge("body", "right-arm") + stickman.add_edge("body", "left-arm") + stickman.add_edge("body", "right-leg") + stickman.add_edge("body", "left-leg") stickman.add_symmetry(node1="left-arm", node2="right-arm") stickman.add_symmetry(node1="left-leg", node2="right-leg") return stickman + @pytest.fixture def skeleton(): @@ -36,4 +40,3 @@ def skeleton(): skeleton.add_symmetry(node1="left-wing", node2="right-wing") return skeleton - diff --git a/tests/fixtures/videos.py b/tests/fixtures/videos.py index ea4369790..fc55e6019 100644 --- a/tests/fixtures/videos.py +++ b/tests/fixtures/videos.py @@ -8,26 +8,42 @@ TEST_H5_AFFINITY = "/pafs" TEST_H5_INPUT_FORMAT = "channels_first" + @pytest.fixture def hdf5_vid(): - return Video.from_hdf5(filename=TEST_H5_FILE, dataset=TEST_H5_DSET, input_format=TEST_H5_INPUT_FORMAT) + return Video.from_hdf5( + filename=TEST_H5_FILE, dataset=TEST_H5_DSET, input_format=TEST_H5_INPUT_FORMAT + ) + @pytest.fixture def hdf5_confmaps(): - return Video.from_hdf5(filename=TEST_H5_FILE, dataset=TEST_H5_CONFMAPS, input_format=TEST_H5_INPUT_FORMAT) - + return Video.from_hdf5( + filename=TEST_H5_FILE, + dataset=TEST_H5_CONFMAPS, + input_format=TEST_H5_INPUT_FORMAT, + ) + + @pytest.fixture def hdf5_affinity(): - return Video.from_hdf5(filename=TEST_H5_FILE, dataset=TEST_H5_AFFINITY, input_format=TEST_H5_INPUT_FORMAT, convert_range=False) + return Video.from_hdf5( + filename=TEST_H5_FILE, + dataset=TEST_H5_AFFINITY, + input_format=TEST_H5_INPUT_FORMAT, + convert_range=False, + ) TEST_SMALL_ROBOT_MP4_FILE = "tests/data/videos/small_robot.mp4" TEST_SMALL_CENTERED_PAIR_VID = "tests/data/videos/centered_pair_small.mp4" + @pytest.fixture def small_robot_mp4_vid(): return Video.from_media(TEST_SMALL_ROBOT_MP4_FILE) + @pytest.fixture def centered_pair_vid(): - return Video.from_media(TEST_SMALL_CENTERED_PAIR_VID) \ No newline at end of file + return Video.from_media(TEST_SMALL_CENTERED_PAIR_VID) diff --git a/tests/gui/test_active.py b/tests/gui/test_active.py new file mode 100644 index 000000000..b3c565e4f --- /dev/null +++ b/tests/gui/test_active.py @@ -0,0 +1,142 @@ +import os + +from sleap.skeleton import Skeleton +from sleap.instance import Instance, Point, LabeledFrame, PredictedInstance +from sleap.io.video import Video +from sleap.io.dataset import Labels +from sleap.nn.model import ModelOutputType +from sleap.gui.active import ( + ActiveLearningDialog, + make_default_training_jobs, + find_saved_jobs, + add_frames_from_json, +) + + +def test_active_gui(qtbot, centered_pair_labels): + win = ActiveLearningDialog( + labels_filename="foo.json", labels=centered_pair_labels, mode="expert" + ) + win.show() + qtbot.addWidget(win) + + # Make sure we include pafs by default + jobs = win._get_current_training_jobs() + assert ModelOutputType.PART_AFFINITY_FIELD in jobs + + # Test option to not include pafs + assert "_dont_use_pafs" in win.form_widget.fields + win.form_widget.set_form_data(dict(_dont_use_pafs=True)) + jobs = win._get_current_training_jobs() + assert ModelOutputType.PART_AFFINITY_FIELD not in jobs + + +def test_make_default_training_jobs(): + jobs = make_default_training_jobs() + + assert ModelOutputType.CONFIDENCE_MAP in jobs + assert ModelOutputType.PART_AFFINITY_FIELD in jobs + + for output_type in jobs: + assert jobs[output_type].model.output_type == output_type + assert jobs[output_type].best_model_filename is None + + +def test_find_saved_jobs(): + jobs_a = find_saved_jobs("tests/data/training_profiles/set_a") + assert len(jobs_a) == 3 + assert len(jobs_a[ModelOutputType.CONFIDENCE_MAP]) == 1 + + jobs_b = find_saved_jobs("tests/data/training_profiles/set_b") + assert len(jobs_b) == 1 + + path, job = jobs_b[ModelOutputType.CONFIDENCE_MAP][0] + assert os.path.basename(path) == "test_confmaps.json" + assert job.trainer.num_epochs == 17 + + # Add jobs from set_a to already loaded jobs from set_b + jobs_c = find_saved_jobs("tests/data/training_profiles/set_a", jobs_b) + assert len(jobs_c) == 3 + + # Make sure we now have two confmap jobs + assert len(jobs_c[ModelOutputType.CONFIDENCE_MAP]) == 2 + + # Make sure set_a was added after items from set_b + paths = [name for (name, job) in jobs_c[ModelOutputType.CONFIDENCE_MAP]] + assert os.path.basename(paths[0]) == "test_confmaps.json" + assert os.path.basename(paths[1]) == "default_confmaps.json" + + +def test_add_frames_from_json(): + vid_a = Video.from_filename("foo.mp4") + vid_b = Video.from_filename("bar.mp4") + + skeleton_a = Skeleton.load_json("tests/data/skeleton/fly_skeleton_legs.json") + skeleton_b = Skeleton.load_json("tests/data/skeleton/fly_skeleton_legs.json") + + lf_a = LabeledFrame(vid_a, frame_idx=2, instances=[Instance(skeleton_a)]) + lf_b = LabeledFrame(vid_b, frame_idx=3, instances=[Instance(skeleton_b)]) + + empty_labels = Labels() + labels_with_video = Labels(videos=[vid_a]) + labels_with_skeleton = Labels(skeletons=[skeleton_a]) + + new_labels_a = Labels(labeled_frames=[lf_a]) + new_labels_b = Labels(labeled_frames=[lf_b]) + + json_a = new_labels_a.to_dict() + json_b = new_labels_b.to_dict() + + # Test with empty labels + + assert len(empty_labels.labeled_frames) == 0 + assert len(empty_labels.skeletons) == 0 + assert len(empty_labels.skeletons) == 0 + + add_frames_from_json(empty_labels, json_a) + assert len(empty_labels.labeled_frames) == 1 + assert len(empty_labels.videos) == 1 + assert len(empty_labels.skeletons) == 1 + + add_frames_from_json(empty_labels, json_b) + assert len(empty_labels.labeled_frames) == 2 + assert len(empty_labels.videos) == 2 + assert len(empty_labels.skeletons) == 1 + + empty_labels.to_dict() + + # Test with labels that have video + + assert len(labels_with_video.labeled_frames) == 0 + assert len(labels_with_video.skeletons) == 0 + assert len(labels_with_video.videos) == 1 + + add_frames_from_json(labels_with_video, json_a) + assert len(labels_with_video.labeled_frames) == 1 + assert len(labels_with_video.videos) == 1 + assert len(labels_with_video.skeletons) == 1 + + add_frames_from_json(labels_with_video, json_b) + assert len(labels_with_video.labeled_frames) == 2 + assert len(labels_with_video.videos) == 2 + assert len(labels_with_video.skeletons) == 1 + + labels_with_video.to_dict() + + # Test with labels that have skeleton + + assert len(labels_with_skeleton.labeled_frames) == 0 + assert len(labels_with_skeleton.skeletons) == 1 + assert len(labels_with_skeleton.videos) == 0 + + add_frames_from_json(labels_with_skeleton, json_a) + assert len(labels_with_skeleton.labeled_frames) == 1 + assert len(labels_with_skeleton.videos) == 1 + assert len(labels_with_skeleton.skeletons) == 1 + + add_frames_from_json(labels_with_skeleton, json_b) + assert len(labels_with_skeleton.labeled_frames) == 2 + assert len(labels_with_skeleton.videos) == 2 + assert len(labels_with_skeleton.skeletons) == 1 + + labels_with_skeleton.to_dict() diff --git a/tests/gui/test_conf_maps_view.py b/tests/gui/test_conf_maps_view.py index eafb56497..5a97276a8 100644 --- a/tests/gui/test_conf_maps_view.py +++ b/tests/gui/test_conf_maps_view.py @@ -5,13 +5,14 @@ import PySide2.QtCore as QtCore + def test_gui_conf_maps(qtbot, hdf5_confmaps): - + vp = QtVideoPlayer() vp.show() conf_maps = ConfMapsPlot(hdf5_confmaps.get_frame(1), show_box=False) vp.view.scene.addItem(conf_maps) - + # make sure we're showing all the channels assert len(conf_maps.childItems()) == 6 diff --git a/tests/gui/test_dataviews.py b/tests/gui/test_dataviews.py index 3a7541681..9af6dedb8 100644 --- a/tests/gui/test_dataviews.py +++ b/tests/gui/test_dataviews.py @@ -8,8 +8,8 @@ SkeletonNodesTable, SkeletonEdgesTable, LabeledFrameTable, - SkeletonNodeModel - ) + SkeletonNodeModel, +) def test_skeleton_nodes(qtbot, centered_pair_predictions): @@ -24,8 +24,13 @@ def test_skeleton_nodes(qtbot, centered_pair_predictions): table = VideosTable(centered_pair_predictions.videos) table.selectRow(0) - assert table.model().data(table.currentIndex()).find("centered_pair_low_quality.mp4") > -1 + assert ( + table.model().data(table.currentIndex()).find("centered_pair_low_quality.mp4") + > -1 + ) - table = LabeledFrameTable(centered_pair_predictions.labels[13], centered_pair_predictions) + table = LabeledFrameTable( + centered_pair_predictions.labels[13], centered_pair_predictions + ) table.selectRow(1) assert table.model().data(table.currentIndex()) == "21/24" diff --git a/tests/gui/test_import.py b/tests/gui/test_import.py index 1ed473338..760d1839e 100644 --- a/tests/gui/test_import.py +++ b/tests/gui/test_import.py @@ -2,33 +2,47 @@ import PySide2.QtCore as QtCore + def test_gui_import(qtbot): file_names = [ - "tests/data/hdf5_format_v1/training.scale=0.50,sigma=10.h5", - "tests/data/videos/small_robot.mp4", - ] + "tests/data/hdf5_format_v1/training.scale=0.50,sigma=10.h5", + "tests/data/videos/small_robot.mp4", + ] importer = ImportParamDialog(file_names) importer.show() qtbot.addWidget(importer) - + data = importer.get_data() assert len(data) == 2 assert len(data[0]["params"]) > 1 - + for import_item in importer.import_widgets: btn = import_item.enabled_checkbox_widget with qtbot.waitSignal(btn.stateChanged, timeout=10): qtbot.mouseClick(btn, QtCore.Qt.LeftButton) assert not import_item.is_enabled() - + assert len(importer.get_data()) == 0 - + for import_item in importer.import_widgets: btn = import_item.enabled_checkbox_widget with qtbot.waitSignal(btn.stateChanged, timeout=10): qtbot.mouseClick(btn, QtCore.Qt.LeftButton) assert import_item.is_enabled() - - assert len(importer.get_data()) == 2 \ No newline at end of file + + assert len(importer.get_data()) == 2 + + +def test_video_import_detect_params(): + importer = ImportParamDialog( + [ + "tests/data/videos/centered_pair_small.mp4", + "tests/data/videos/small_robot.mp4", + ] + ) + data = importer.get_data() + + assert data[0]["params"]["grayscale"] == True + assert data[1]["params"]["grayscale"] == False diff --git a/tests/gui/test_merge.py b/tests/gui/test_merge.py new file mode 100644 index 000000000..6f75753a2 --- /dev/null +++ b/tests/gui/test_merge.py @@ -0,0 +1,5 @@ +from sleap.gui.merge import show_instance_type_counts + + +def test_count_string(simple_predictions): + assert show_instance_type_counts(simple_predictions[0]) == "0/2" diff --git a/tests/gui/test_monitor.py b/tests/gui/test_monitor.py new file mode 100644 index 000000000..d678897e8 --- /dev/null +++ b/tests/gui/test_monitor.py @@ -0,0 +1,12 @@ +from sleap.nn.monitor import LossViewer + + +def test_monitor_release(qtbot): + win = LossViewer() + win.show() + win.close() + + # Make sure the first monitor released its zmq socket + win2 = LossViewer() + win2.show() + win2.close() diff --git a/tests/gui/test_multicheck.py b/tests/gui/test_multicheck.py index cc936d2de..2a2ee0bbe 100644 --- a/tests/gui/test_multicheck.py +++ b/tests/gui/test_multicheck.py @@ -2,23 +2,24 @@ import PySide2.QtCore as QtCore + def test_gui_video(qtbot): cs = MultiCheckWidget(count=10, title="Test", default=True) cs.show() qtbot.addWidget(cs) - + assert cs.getSelected() == list(range(10)) - + for btn in cs.check_group.buttons(): # click all the odd buttons to uncheck them if cs.check_group.id(btn) % 2 == 1: - qtbot.mouseClick(btn, QtCore.Qt.LeftButton) - assert cs.getSelected() == list(range(0,10,2)) - - cs.setSelected([1,2,3]) - assert cs.getSelected() == [1,2,3] - + qtbot.mouseClick(btn, QtCore.Qt.LeftButton) + assert cs.getSelected() == list(range(0, 10, 2)) + + cs.setSelected([1, 2, 3]) + assert cs.getSelected() == [1, 2, 3] + # Watch for the app.worker.finished signal, then start the worker. with qtbot.waitSignal(cs.selectionChanged, timeout=10): qtbot.mouseClick(cs.check_group.buttons()[0], QtCore.Qt.LeftButton) diff --git a/tests/gui/test_quiver.py b/tests/gui/test_quiver.py index a1b877c14..6875cbd2d 100644 --- a/tests/gui/test_quiver.py +++ b/tests/gui/test_quiver.py @@ -5,17 +5,16 @@ import PySide2.QtCore as QtCore + def test_gui_quiver(qtbot, hdf5_affinity): - + vp = QtVideoPlayer() vp.show() affinity_fields = MultiQuiverPlot( - frame=hdf5_affinity.get_frame(0)[265:275,238:248], - show=[0,1], - decimation=1 - ) + frame=hdf5_affinity.get_frame(0)[265:275, 238:248], show=[0, 1], decimation=1 + ) vp.view.scene.addItem(affinity_fields) - + # make sure we're showing all the channels we selected assert len(affinity_fields.childItems()) == 2 # make sure we're showing all arrows in first channel diff --git a/tests/gui/test_shortcuts.py b/tests/gui/test_shortcuts.py new file mode 100644 index 000000000..67c900bca --- /dev/null +++ b/tests/gui/test_shortcuts.py @@ -0,0 +1,13 @@ +from PySide2.QtGui import QKeySequence + +from sleap.gui.shortcuts import Shortcuts + + +def test_shortcuts(): + shortcuts = Shortcuts() + + assert shortcuts["new"] == shortcuts[0] + assert shortcuts["new"] == QKeySequence.fromString("Ctrl+N") + shortcuts["new"] = QKeySequence.fromString("Ctrl+Shift+N") + assert shortcuts["new"] == QKeySequence.fromString("Ctrl+Shift+N") + assert list(shortcuts[0:2].keys()) == ["new", "open"] diff --git a/tests/gui/test_slider.py b/tests/gui/test_slider.py index f69164f3b..0d05b057b 100644 --- a/tests/gui/test_slider.py +++ b/tests/gui/test_slider.py @@ -1,26 +1,27 @@ from sleap.gui.slider import VideoSlider + def test_slider(qtbot, centered_pair_predictions): - + labels = centered_pair_predictions - - slider = VideoSlider(min=0, max=1200, val=15, marks=(10,15)) - + + slider = VideoSlider(min=0, max=1200, val=15, marks=(10, 15)) + assert slider.value() == 15 slider.setValue(20) assert slider.value() == 20 - + assert slider.getSelection() == (0, 0) slider.startSelection(3) slider.endSelection(5) assert slider.getSelection() == (3, 5) slider.clearSelection() assert slider.getSelection() == (0, 0) - + initial_height = slider.maximumHeight() slider.setTracks(20) assert slider.maximumHeight() != initial_height - + slider.setTracksFromLabels(labels, labels.videos[0]) assert len(slider.getMarks()) == 40 diff --git a/tests/gui/test_tracks.py b/tests/gui/test_tracks.py index 4c4481931..5cc00cbc2 100644 --- a/tests/gui/test_tracks.py +++ b/tests/gui/test_tracks.py @@ -1,37 +1,39 @@ from sleap.gui.overlays.tracks import TrackColorManager, TrackTrailOverlay from sleap.io.video import Video + def test_track_trails(centered_pair_predictions): - + labels = centered_pair_predictions - trail_manager = TrackTrailOverlay(labels, scene=None, trail_length = 6) - + trail_manager = TrackTrailOverlay(labels, player=None, trail_length=6) + frames = trail_manager.get_frame_selection(labels.videos[0], 27) assert len(frames) == 6 assert frames[0].frame_idx == 22 - + tracks = trail_manager.get_tracks_in_frame(labels.videos[0], 27) assert len(tracks) == 2 assert tracks[0].name == "1" assert tracks[1].name == "2" trails = trail_manager.get_track_trails(frames, tracks[0]) - + assert len(trails) == 24 - - test_trail = [(245.0, 208.0), + + test_trail = [ + (245.0, 208.0), (245.0, 207.0), (245.0, 206.0), (246.0, 205.0), (247.0, 203.0), - (248.0, 202.0) - ] + (248.0, 202.0), + ] assert test_trail in trails - + # Test track colors color_manager = TrackColorManager(labels=labels) tracks = trail_manager.get_tracks_in_frame(labels.videos[0], 1099) assert len(tracks) == 5 - assert color_manager.get_color(tracks[3]) == [119, 172, 48] \ No newline at end of file + assert color_manager.get_color(tracks[3]) == [119, 172, 48] diff --git a/tests/gui/test_video_player.py b/tests/gui/test_video_player.py index 7a227fce9..ff981a5dd 100644 --- a/tests/gui/test_video_player.py +++ b/tests/gui/test_video_player.py @@ -2,6 +2,7 @@ import PySide2.QtCore as QtCore + def test_gui_video(qtbot): vp = QtVideoPlayer() vp.show() @@ -13,16 +14,17 @@ def test_gui_video(qtbot): # for i in range(20): # qtbot.mouseClick(vp.btn, QtCore.Qt.LeftButton) + def test_gui_video_instances(qtbot, small_robot_mp4_vid, centered_pair_labels): vp = QtVideoPlayer(small_robot_mp4_vid) qtbot.addWidget(vp) - test_frame_idx = 0 - labeled_frames = [_ for _ in centered_pair_labels if _.frame_idx == test_frame_idx] + test_frame_idx = 63 + labeled_frames = centered_pair_labels.labeled_frames def plot_instances(vp, idx): - for instance in labeled_frames[idx].instances: - vp.addInstance(instance=instance, color=(0,0,128)) + for instance in labeled_frames[test_frame_idx].instances: + vp.addInstance(instance=instance, color=(0, 0, 128)) vp.changedPlot.connect(plot_instances) vp.view.updatedViewer.emit() @@ -31,15 +33,15 @@ def plot_instances(vp, idx): vp.plot() # Check that all instances are included in viewer - assert len(vp.instances) == len(labeled_frames[0].instances) + assert len(vp.instances) == len(labeled_frames[test_frame_idx].instances) vp.zoomToFit() # Check that we zoomed correctly - assert(vp.view.zoomFactor > 2) - + assert vp.view.zoomFactor > 1 + vp.instances[0].updatePoints(complete=True) - + # Check that node is marked as complete assert vp.instances[0].childItems()[3].point.complete @@ -50,6 +52,11 @@ def plot_instances(vp, idx): qtbot.keyClick(vp, QtCore.Qt.Key_QuoteLeft) assert vp.view.getSelection() == 1 + # Check that selection by Instance works + for inst in labeled_frames[test_frame_idx].instances: + vp.view.selectInstance(inst) + assert vp.view.getSelectionInstance() == inst + # Check that sequence selection works with qtbot.waitCallback() as cb: vp.view.clearSelection() @@ -58,4 +65,4 @@ def plot_instances(vp, idx): qtbot.keyClick(vp, QtCore.Qt.Key_1) assert cb.args[0] == [1, 0] - assert vp.close() \ No newline at end of file + assert vp.close() diff --git a/tests/info/test_h5.py b/tests/info/test_h5.py new file mode 100644 index 000000000..e93e1bc7b --- /dev/null +++ b/tests/info/test_h5.py @@ -0,0 +1,89 @@ +import os + +import h5py +import numpy as np + +from sleap.info.write_tracking_h5 import ( + get_tracks_as_np_strings, + get_occupancy_and_points_matrices, + remove_empty_tracks_from_matrices, + write_occupancy_file, +) + + +def test_output_matrices(centered_pair_predictions): + + names = get_tracks_as_np_strings(centered_pair_predictions) + assert len(names) == 27 + assert isinstance(names[0], np.string_) + + # Remove the first labeled frame + del centered_pair_predictions[0] + assert len(centered_pair_predictions) == 1099 + + occupancy, points = get_occupancy_and_points_matrices( + centered_pair_predictions, all_frames=False + ) + + assert occupancy.shape == (27, 1099) + assert points.shape == (1099, 24, 2, 27) + + # Make sure "all_frames" includes the missing initial frame + occupancy, points = get_occupancy_and_points_matrices( + centered_pair_predictions, all_frames=True + ) + + assert occupancy.shape == (27, 1100) + assert points.shape == (1100, 24, 2, 27) + + # Make sure removing empty tracks doesn't yet change anything + names, occupancy, points = remove_empty_tracks_from_matrices( + names, occupancy, points + ) + + assert len(names) == 27 + assert occupancy.shape == (27, 1100) + assert points.shape == (1100, 24, 2, 27) + + # Remove all instances from track 13 + vid = centered_pair_predictions.videos[0] + track = centered_pair_predictions.tracks[13] + instances = centered_pair_predictions.find_track_occupancy(vid, track) + for instance in instances: + centered_pair_predictions.remove_instance(instance.frame, instance) + + # Make sure that this now remove empty track + occupancy, points = get_occupancy_and_points_matrices( + centered_pair_predictions, all_frames=True + ) + names, occupancy, points = remove_empty_tracks_from_matrices( + names, occupancy, points + ) + + assert len(names) == 26 + assert occupancy.shape == (26, 1100) + assert points.shape == (1100, 24, 2, 26) + + +def test_hdf5_saving(tmpdir): + path = os.path.join(tmpdir, "occupany.h5") + + x = np.array([[1, 2, 6], [3, 4, 5]]) + data_dict = dict(x=x) + + write_occupancy_file(path, data_dict, transpose=False) + + with h5py.File(path, "r") as f: + assert f["x"].shape == x.shape + + +def test_hdf5_tranposed_saving(tmpdir): + path = os.path.join(tmpdir, "transposed.h5") + + x = np.array([[1, 2, 6], [3, 4, 5]]) + data_dict = dict(x=x) + + write_occupancy_file(path, data_dict, transpose=True) + + with h5py.File(path, "r") as f: + assert f["x"].shape == np.transpose(x).shape diff --git a/tests/info/test_summary.py b/tests/info/test_summary.py new file mode 100644 index 000000000..2cf76c166 --- /dev/null +++ b/tests/info/test_summary.py @@ -0,0 +1,42 @@ +from sleap.info.summary import StatisticSeries + + +def test_frame_statistics(simple_predictions): + video = simple_predictions.videos[0] + stats = StatisticSeries(simple_predictions) + + x = stats.get_point_count_series(video) + assert len(x) == 2 + assert x[0] == 4 + assert x[1] == 4 + + x = stats.get_point_score_series(video, "sum") + assert len(x) == 2 + assert x[0] == 2.4 + assert x[1] == 5.2 + + x = stats.get_point_score_series(video, "min") + assert len(x) == 2 + assert x[0] == 0.5 + assert x[1] == 1.0 + + x = stats.get_instance_score_series(video, "sum") + assert len(x) == 2 + assert x[0] == 7 + assert x[1] == 9 + + x = stats.get_instance_score_series(video, "min") + assert len(x) == 2 + assert x[0] == 2 + assert x[1] == 3 + + x = stats.get_point_displacement_series(video, "mean") + assert len(x) == 2 + assert x[0] == 0 + assert x[1] == 9.0 + + x = stats.get_point_displacement_series(video, "max") + assert len(x) == 2 + assert len(x) == 2 + assert x[0] == 0 + assert x[1] == 18.0 diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py index 70f2d37ff..f141ae6fc 100644 --- a/tests/io/test_dataset.py +++ b/tests/io/test_dataset.py @@ -5,13 +5,16 @@ from sleap.skeleton import Skeleton from sleap.instance import Instance, Point, LabeledFrame, PredictedInstance from sleap.io.video import Video, MediaVideo -from sleap.io.dataset import Labels, load_labels_json_old +from sleap.io.dataset import Labels +from sleap.io.legacy import load_labels_json_old +from sleap.gui.suggestions import VideoFrameSuggestions -TEST_H5_DATASET = 'tests/data/hdf5_format_v1/training.scale=0.50,sigma=10.h5' +TEST_H5_DATASET = "tests/data/hdf5_format_v1/training.scale=0.50,sigma=10.h5" -def _check_labels_match(expected_labels, other_labels, format = 'png'): + +def _check_labels_match(expected_labels, other_labels, format="png"): """ - A utitlity function to check whether to sets of labels match. + A utility function to check whether to sets of labels match. This doesn't directly compares some things (like video objects). Args: @@ -42,7 +45,10 @@ def dict_match(dict1, dict2): # Check if the graphs are iso-morphic import networkx as nx - is_isomorphic = nx.is_isomorphic(self._graph, other._graph, node_match=dict_match) + + is_isomorphic = nx.is_isomorphic( + self._graph, other._graph, node_match=dict_match + ) if not is_isomorphic: assert False @@ -68,11 +74,14 @@ def dict_match(dict1, dict2): # Compare the first frames of the videos, do it on a small sub-region to # make the test reasonable in time. - if format is 'png': + if format is "png": assert np.allclose(frame_data, expected_frame_data) # Compare the instances - assert all(i1.matches(i2) for (i1, i2) in zip(expected_label.instances, label.instances)) + assert all( + i1.matches(i2) + for (i1, i2) in zip(expected_label.instances, label.instances) + ) # This test takes to long, break after 20 or so. if frame_idx > 20: @@ -80,7 +89,7 @@ def dict_match(dict1, dict2): def test_labels_json(tmpdir, multi_skel_vid_labels): - json_file_path = os.path.join(tmpdir, 'dataset.json') + json_file_path = os.path.join(tmpdir, "dataset.json") if os.path.isfile(json_file_path): os.remove(json_file_path) @@ -111,17 +120,38 @@ def test_labels_json(tmpdir, multi_skel_vid_labels): assert multi_skel_vid_labels.nodes[3] in loaded_labels.nodes assert multi_skel_vid_labels.videos[0] in loaded_labels.videos + def test_load_labels_json_old(tmpdir): - new_file_path = os.path.join(tmpdir, 'centered_pair_v2.json') + new_file_path = os.path.join(tmpdir, "centered_pair_v2.json") # Function to run some checks on loaded labels def check_labels(labels): - skel_node_names = ['head', 'neck', 'thorax', 'abdomen', 'wingL', - 'wingR', 'forelegL1', 'forelegL2', 'forelegL3', - 'forelegR1', 'forelegR2', 'forelegR3', 'midlegL1', - 'midlegL2', 'midlegL3', 'midlegR1', 'midlegR2', - 'midlegR3', 'hindlegL1', 'hindlegL2', 'hindlegL3', - 'hindlegR1', 'hindlegR2', 'hindlegR3'] + skel_node_names = [ + "head", + "neck", + "thorax", + "abdomen", + "wingL", + "wingR", + "forelegL1", + "forelegL2", + "forelegL3", + "forelegR1", + "forelegR2", + "forelegR3", + "midlegL1", + "midlegL2", + "midlegL3", + "midlegR1", + "midlegR2", + "midlegR3", + "hindlegL1", + "hindlegL2", + "hindlegL3", + "hindlegR1", + "hindlegR2", + "hindlegR3", + ] # Do some basic checks assert len(labels) == 70 @@ -155,6 +185,23 @@ def test_label_accessors(centered_pair_labels): assert len(labels.find(video)) == 70 assert labels[video] == labels.find(video) + f = labels.frames(video, from_frame_idx=1) + assert next(f).frame_idx == 15 + assert next(f).frame_idx == 31 + + f = labels.frames(video, from_frame_idx=31, reverse=True) + assert next(f).frame_idx == 15 + + f = labels.frames(video, from_frame_idx=0, reverse=True) + assert next(f).frame_idx == 1092 + next(f) + next(f) + # test that iterator now has fewer items left + assert len(list(f)) == 70 - 3 + + assert labels.instance_count(video, 15) == 2 + assert labels.instance_count(video, 7) == 0 + assert labels[0].video == video assert labels[0].frame_idx == 0 @@ -166,6 +213,7 @@ def test_label_accessors(centered_pair_labels): assert labels.find(video, 954)[0] == labels[61] assert labels.find_first(video) == labels[0] assert labels.find_first(video, 954) == labels[61] + assert labels.find_last(video) == labels[69] assert labels[video, 954] == labels[61] assert labels[video, 0] == labels[0] assert labels[video] == labels.labels @@ -185,7 +233,7 @@ def test_label_mutability(): dummy_video = Video(backend=MediaVideo) dummy_skeleton = Skeleton() dummy_instance = Instance(dummy_skeleton) - dummy_frame = LabeledFrame(dummy_video, frame_idx=0, instances=[dummy_instance,]) + dummy_frame = LabeledFrame(dummy_video, frame_idx=0, instances=[dummy_instance]) labels = Labels() labels.append(dummy_frame) @@ -202,7 +250,7 @@ def test_label_mutability(): dummy_video2 = Video(backend=MediaVideo) dummy_skeleton2 = Skeleton(name="dummy2") dummy_instance2 = Instance(dummy_skeleton2) - dummy_frame2 = LabeledFrame(dummy_video2, frame_idx=0, instances=[dummy_instance2,]) + dummy_frame2 = LabeledFrame(dummy_video2, frame_idx=0, instances=[dummy_instance2]) assert dummy_video2 not in labels assert dummy_skeleton2 not in labels assert dummy_frame2 not in labels @@ -226,9 +274,9 @@ def test_label_mutability(): for f in dummy_frames + dummy_frames2: labels.append(f) - assert(len(labels) == 20) + assert len(labels) == 20 labels.remove_video(dummy_video2) - assert(len(labels) == 10) + assert len(labels) == 10 assert len(labels.find(dummy_video)) == 10 assert dummy_frame in labels @@ -241,15 +289,224 @@ def test_label_mutability(): labels.remove_video(dummy_video) assert len(labels.find(dummy_video)) == 0 - dummy_frames3 = [LabeledFrame(dummy_video, frame_idx=0, instances=[dummy_instance,]) for _ in range(10)] - labels.labeled_frames.extend(dummy_frames3) + +def test_labels_merge(): + dummy_video = Video(backend=MediaVideo) + dummy_skeleton = Skeleton() + dummy_skeleton.add_node("node") + + labels = Labels() + dummy_frames = [] + + # Add 10 instances with different points (so they aren't "redundant") + for i in range(10): + instance = Instance(skeleton=dummy_skeleton, points=dict(node=Point(i, i))) + dummy_frame = LabeledFrame(dummy_video, frame_idx=0, instances=[instance]) + dummy_frames.append(dummy_frame) + + labels.labeled_frames.extend(dummy_frames) assert len(labels) == 10 assert len(labels.labeled_frames[0].instances) == 1 + labels.merge_matching_frames() assert len(labels) == 1 assert len(labels.labeled_frames[0].instances) == 10 +def test_complex_merge(): + dummy_video_a = Video.from_filename("foo.mp4") + dummy_video_b = Video.from_filename("foo.mp4") + + dummy_skeleton_a = Skeleton() + dummy_skeleton_a.add_node("node") + + dummy_skeleton_b = Skeleton() + dummy_skeleton_b.add_node("node") + + dummy_instances_a = [] + dummy_instances_a.append( + Instance(skeleton=dummy_skeleton_a, points=dict(node=Point(1, 1))) + ) + dummy_instances_a.append( + Instance(skeleton=dummy_skeleton_a, points=dict(node=Point(2, 2))) + ) + + labels_a = Labels() + labels_a.append( + LabeledFrame(dummy_video_a, frame_idx=0, instances=dummy_instances_a) + ) + + dummy_instances_b = [] + dummy_instances_b.append( + Instance(skeleton=dummy_skeleton_b, points=dict(node=Point(1, 1))) + ) + dummy_instances_b.append( + Instance(skeleton=dummy_skeleton_b, points=dict(node=Point(3, 3))) + ) + + labels_b = Labels() + labels_b.append( + LabeledFrame(dummy_video_b, frame_idx=0, instances=dummy_instances_b) + ) # conflict + labels_b.append( + LabeledFrame(dummy_video_b, frame_idx=1, instances=dummy_instances_b) + ) # clean + + merged, extra_a, extra_b = Labels.complex_merge_between(labels_a, labels_b) + + # Check that we have the cleanly merged frame + assert dummy_video_a in merged + assert len(merged[dummy_video_a]) == 1 # one merged frame + assert len(merged[dummy_video_a][1]) == 2 # with two instances + + # Check that labels_a includes redundant and clean + assert len(labels_a.labeled_frames) == 2 + assert len(labels_a.labeled_frames[0].instances) == 1 + assert labels_a.labeled_frames[0].instances[0].points[0].x == 1 + assert len(labels_a.labeled_frames[1].instances) == 2 + assert labels_a.labeled_frames[1].instances[0].points[0].x == 1 + assert labels_a.labeled_frames[1].instances[1].points[0].x == 3 + + # Check that extra_a/b includes the appropriate conflicting instance + assert len(extra_a) == 1 + assert len(extra_b) == 1 + assert len(extra_a[0].instances) == 1 + assert len(extra_b[0].instances) == 1 + assert extra_a[0].instances[0].points[0].x == 2 + assert extra_b[0].instances[0].points[0].x == 3 + + # Check that objects were unified + assert extra_a[0].video == extra_b[0].video + + # Check resolving the conflict using new + Labels.finish_complex_merge(labels_a, extra_b) + assert len(labels_a.labeled_frames) == 2 + assert len(labels_a.labeled_frames[0].instances) == 2 + assert labels_a.labeled_frames[0].instances[1].points[0].x == 3 + + +def test_merge_predictions(): + dummy_video_a = Video.from_filename("foo.mp4") + dummy_video_b = Video.from_filename("foo.mp4") + + dummy_skeleton_a = Skeleton() + dummy_skeleton_a.add_node("node") + + dummy_skeleton_b = Skeleton() + dummy_skeleton_b.add_node("node") + + dummy_instances_a = [] + dummy_instances_a.append( + Instance(skeleton=dummy_skeleton_a, points=dict(node=Point(1, 1))) + ) + dummy_instances_a.append( + Instance(skeleton=dummy_skeleton_a, points=dict(node=Point(2, 2))) + ) + + labels_a = Labels() + labels_a.append( + LabeledFrame(dummy_video_a, frame_idx=0, instances=dummy_instances_a) + ) + + dummy_instances_b = [] + dummy_instances_b.append( + Instance(skeleton=dummy_skeleton_b, points=dict(node=Point(1, 1))) + ) + dummy_instances_b.append( + PredictedInstance( + skeleton=dummy_skeleton_b, points=dict(node=Point(3, 3)), score=1 + ) + ) + + labels_b = Labels() + labels_b.append( + LabeledFrame(dummy_video_b, frame_idx=0, instances=dummy_instances_b) + ) + + # Frames have one redundant instance (perfect match) and all the + # non-matching instances are different types (one predicted, one not). + merged, extra_a, extra_b = Labels.complex_merge_between(labels_a, labels_b) + assert len(merged[dummy_video_a]) == 1 + assert len(merged[dummy_video_a][0]) == 1 # the predicted instance was merged + assert not extra_a + assert not extra_b + + +def skeleton_ids_from_label_instances(labels): + return list(map(id, (lf.instances[0].skeleton for lf in labels.labeled_frames))) + + +def test_duplicate_skeletons_serializing(): + vid = Video.from_filename("foo.mp4") + + skeleton_a = Skeleton.load_json("tests/data/skeleton/fly_skeleton_legs.json") + skeleton_b = Skeleton.load_json("tests/data/skeleton/fly_skeleton_legs.json") + + lf_a = LabeledFrame(vid, frame_idx=2, instances=[Instance(skeleton_a)]) + lf_b = LabeledFrame(vid, frame_idx=3, instances=[Instance(skeleton_b)]) + + new_labels = Labels(labeled_frames=[lf_a, lf_b]) + new_labels_json = new_labels.to_dict() + + +def test_distinct_skeletons_serializing(): + vid = Video.from_filename("foo.mp4") + + skeleton_a = Skeleton.load_json("tests/data/skeleton/fly_skeleton_legs.json") + skeleton_b = Skeleton.load_json("tests/data/skeleton/fly_skeleton_legs.json") + skeleton_b.add_node("foo") + + lf_a = LabeledFrame(vid, frame_idx=2, instances=[Instance(skeleton_a)]) + lf_b = LabeledFrame(vid, frame_idx=3, instances=[Instance(skeleton_b)]) + + new_labels = Labels(labeled_frames=[lf_a, lf_b]) + + # Make sure we can serialize this + new_labels_json = new_labels.to_dict() + + +def test_unify_skeletons(): + vid = Video.from_filename("foo.mp4") + + skeleton_a = Skeleton.load_json("tests/data/skeleton/fly_skeleton_legs.json") + skeleton_b = Skeleton.load_json("tests/data/skeleton/fly_skeleton_legs.json") + + lf_a = LabeledFrame(vid, frame_idx=2, instances=[Instance(skeleton_a)]) + lf_b = LabeledFrame(vid, frame_idx=3, instances=[Instance(skeleton_b)]) + + labels = Labels() + labels.extend_from([lf_a], unify=True) + labels.extend_from([lf_b], unify=True) + ids = skeleton_ids_from_label_instances(labels) + + # Make sure that skeleton_b got replaced with skeleton_a when we + # added the frame with "unify" set + assert len(set(ids)) == 1 + + # Make sure we can serialize this + labels.to_dict() + + +def test_dont_unify_skeletons(): + vid = Video.from_filename("foo.mp4") + + skeleton_a = Skeleton.load_json("tests/data/skeleton/fly_skeleton_legs.json") + skeleton_b = Skeleton.load_json("tests/data/skeleton/fly_skeleton_legs.json") + + lf_a = LabeledFrame(vid, frame_idx=2, instances=[Instance(skeleton_a)]) + lf_b = LabeledFrame(vid, frame_idx=3, instances=[Instance(skeleton_b)]) + + labels = Labels(labeled_frames=[lf_a]) + labels.extend_from([lf_b], unify=False) + ids = skeleton_ids_from_label_instances(labels) + + # Make sure we still have two distinct skeleton objects + assert len(set(ids)) == 2 + + # Make sure we can serialize this + labels.to_dict() + + def test_instance_access(): labels = Labels() @@ -258,21 +515,69 @@ def test_instance_access(): dummy_video2 = Video(backend=MediaVideo) for i in range(10): - labels.append(LabeledFrame(dummy_video, frame_idx=i, instances=[Instance(dummy_skeleton), Instance(dummy_skeleton)])) + labels.append( + LabeledFrame( + dummy_video, + frame_idx=i, + instances=[Instance(dummy_skeleton), Instance(dummy_skeleton)], + ) + ) for i in range(10): - labels.append(LabeledFrame(dummy_video2, frame_idx=i, instances=[Instance(dummy_skeleton), Instance(dummy_skeleton), Instance(dummy_skeleton)])) + labels.append( + LabeledFrame( + dummy_video2, + frame_idx=i, + instances=[ + Instance(dummy_skeleton), + Instance(dummy_skeleton), + Instance(dummy_skeleton), + ], + ) + ) assert len(labels.all_instances) == 50 assert len(list(labels.instances(video=dummy_video))) == 20 assert len(list(labels.instances(video=dummy_video2))) == 30 +def test_suggestions(small_robot_mp4_vid): + dummy_video = small_robot_mp4_vid + dummy_skeleton = Skeleton() + dummy_instance = Instance(dummy_skeleton) + dummy_frame = LabeledFrame(dummy_video, frame_idx=0, instances=[dummy_instance]) + + labels = Labels() + labels.append(dummy_frame) + + suggestions = dict() + suggestions[dummy_video] = VideoFrameSuggestions.suggest( + dummy_video, params=dict(method="random", per_video=13) + ) + labels.set_suggestions(suggestions) + + assert len(labels.get_video_suggestions(dummy_video)) == 13 + + +def test_negative_anchors(): + video = Video.from_filename("foo.mp4") + labels = Labels() + + labels.add_negative_anchor(video, 1, (3, 4)) + labels.add_negative_anchor(video, 1, (7, 8)) + labels.add_negative_anchor(video, 2, (5, 9)) + + assert len(labels.negative_anchors[video]) == 3 + + labels.remove_negative_anchors(video, 1) + assert len(labels.negative_anchors[video]) == 1 + + def test_load_labels_mat(mat_labels): assert len(mat_labels.nodes) == 6 assert len(mat_labels) == 43 -@pytest.mark.parametrize("format", ['png', 'mjpeg/avi']) +@pytest.mark.parametrize("format", ["png", "mjpeg/avi"]) def test_save_labels_with_frame_data(multi_skel_vid_labels, tmpdir, format): """ Test saving and loading a labels dataset with frame data included @@ -282,8 +587,13 @@ def test_save_labels_with_frame_data(multi_skel_vid_labels, tmpdir, format): # Lets take a subset of the labels so this doesn't take too long multi_skel_vid_labels.labeled_frames = multi_skel_vid_labels.labeled_frames[5:30] - filename = os.path.join(tmpdir, 'test.json') - Labels.save_json(multi_skel_vid_labels, filename=filename, save_frame_data=True, frame_data_format=format) + filename = os.path.join(tmpdir, "test.json") + Labels.save_json( + multi_skel_vid_labels, + filename=filename, + save_frame_data=True, + frame_data_format=format, + ) # Load the data back in loaded_labels = Labels.load_json(f"{filename}.zip") @@ -295,9 +605,32 @@ def test_save_labels_with_frame_data(multi_skel_vid_labels, tmpdir, format): loaded_labels = Labels.load_json(f"{filename}.zip") +def test_save_labels_and_frames_hdf5(multi_skel_vid_labels, tmpdir): + # Lets take a subset of the labels so this doesn't take too long + labels = multi_skel_vid_labels + labels.labeled_frames = labels.labeled_frames[5:30] + + filename = os.path.join(tmpdir, "test.h5") + + Labels.save_hdf5(filename=filename, labels=labels, save_frame_data=True) + + loaded_labels = Labels.load_hdf5(filename=filename) + + _check_labels_match(labels, loaded_labels) + + # Rename file (after closing videos) + for vid in loaded_labels.videos: + vid.close() + filerename = os.path.join(tmpdir, "test_rename.h5") + os.rename(filename, filerename) + + # Make sure we open can after rename + loaded_labels = Labels.load_hdf5(filename=filerename) + + def test_labels_hdf5(multi_skel_vid_labels, tmpdir): labels = multi_skel_vid_labels - filename = os.path.join(tmpdir, 'test.h5') + filename = os.path.join(tmpdir, "test.h5") Labels.save_hdf5(filename=filename, labels=labels) @@ -308,7 +641,7 @@ def test_labels_hdf5(multi_skel_vid_labels, tmpdir): def test_labels_predicted_hdf5(multi_skel_vid_labels, tmpdir): labels = multi_skel_vid_labels - filename = os.path.join(tmpdir, 'test.h5') + filename = os.path.join(tmpdir, "test.h5") # Lets promote some of these Instances to predicted instances for label in labels: @@ -339,9 +672,10 @@ def test_labels_predicted_hdf5(multi_skel_vid_labels, tmpdir): loaded_labels = Labels.load_hdf5(filename=filename) _check_labels_match(labels, loaded_labels) + def test_labels_append_hdf5(multi_skel_vid_labels, tmpdir): labels = multi_skel_vid_labels - filename = os.path.join(tmpdir, 'test.h5') + filename = os.path.join(tmpdir, "test.h5") # Save each frame of the Labels dataset one by one in append # mode @@ -359,3 +693,23 @@ def test_labels_append_hdf5(multi_skel_vid_labels, tmpdir): _check_labels_match(labels, loaded_labels) + +def test_hdf5_from_predicted(multi_skel_vid_labels, tmpdir): + labels = multi_skel_vid_labels + filename = os.path.join(tmpdir, "test.h5") + + # Add some predicted instances to create from_predicted links + for frame_num, frame in enumerate(labels): + if frame_num % 20 == 0: + frame.instances[0].from_predicted = PredictedInstance.from_instance( + frame.instances[0], float(frame_num) + ) + frame.instances.append(frame.instances[0].from_predicted) + + # Save and load, compare the results + Labels.save_hdf5(filename=filename, labels=labels) + loaded_labels = Labels.load_hdf5(filename=filename) + + for frame_num, frame in enumerate(loaded_labels): + if frame_num % 20 == 0: + assert frame.instances[0].from_predicted.score == float(frame_num) diff --git a/tests/io/test_video.py b/tests/io/test_video.py index f694cc19a..f6030ff94 100644 --- a/tests/io/test_video.py +++ b/tests/io/test_video.py @@ -3,7 +3,7 @@ import numpy as np -from sleap.io.video import Video +from sleap.io.video import Video, HDF5Video, MediaVideo from tests.fixtures.videos import TEST_H5_FILE, TEST_SMALL_ROBOT_MP4_FILE # FIXME: @@ -11,79 +11,96 @@ # of redundant test code here. # See: https://github.com/pytest-dev/pytest/issues/349 + +def test_from_filename(): + assert type(Video.from_filename(TEST_H5_FILE).backend) == HDF5Video + assert type(Video.from_filename(TEST_SMALL_ROBOT_MP4_FILE).backend) == MediaVideo + + def test_hdf5_get_shape(hdf5_vid): - assert(hdf5_vid.shape == (42, 512, 512, 1)) + assert hdf5_vid.shape == (42, 512, 512, 1) def test_hdf5_len(hdf5_vid): - assert(len(hdf5_vid) == 42) + assert len(hdf5_vid) == 42 def test_hdf5_dtype(hdf5_vid): - assert(hdf5_vid.dtype == np.uint8) + assert hdf5_vid.dtype == np.uint8 def test_hdf5_get_frame(hdf5_vid): - assert(hdf5_vid.get_frame(0).shape == (512, 512, 1)) + assert hdf5_vid.get_frame(0).shape == (512, 512, 1) def test_hdf5_get_frames(hdf5_vid): - assert(hdf5_vid.get_frames(0).shape == (1, 512, 512, 1)) - assert(hdf5_vid.get_frames([0,1]).shape == (2, 512, 512, 1)) + assert hdf5_vid.get_frames(0).shape == (1, 512, 512, 1) + assert hdf5_vid.get_frames([0, 1]).shape == (2, 512, 512, 1) def test_hdf5_get_item(hdf5_vid): - assert(hdf5_vid[0].shape == (1, 512, 512, 1)) - assert(np.alltrue(hdf5_vid[1:10:3] == hdf5_vid.get_frames([1, 4, 7]))) + assert hdf5_vid[0].shape == (1, 512, 512, 1) + assert np.alltrue(hdf5_vid[1:10:3] == hdf5_vid.get_frames([1, 4, 7])) + def test_hd5f_file_not_found(): with pytest.raises(FileNotFoundError): - Video.from_hdf5("non-existent-filename.h5", 'dataset_name') + Video.from_hdf5("non-existent-filename.h5", "dataset_name") + def test_mp4_get_shape(small_robot_mp4_vid): - assert(small_robot_mp4_vid.shape == (166, 320, 560, 3)) + assert small_robot_mp4_vid.shape == (166, 320, 560, 3) + + +def test_mp4_fps(small_robot_mp4_vid): + assert small_robot_mp4_vid.fps == 30.0 def test_mp4_len(small_robot_mp4_vid): - assert(len(small_robot_mp4_vid) == 166) + assert len(small_robot_mp4_vid) == 166 def test_mp4_dtype(small_robot_mp4_vid): - assert(small_robot_mp4_vid.dtype == np.uint8) + assert small_robot_mp4_vid.dtype == np.uint8 def test_mp4_get_frame(small_robot_mp4_vid): - assert(small_robot_mp4_vid.get_frame(0).shape == (320, 560, 3)) + assert small_robot_mp4_vid.get_frame(0).shape == (320, 560, 3) def test_mp4_get_frames(small_robot_mp4_vid): - assert(small_robot_mp4_vid.get_frames(0).shape == (1, 320, 560, 3)) - assert(small_robot_mp4_vid.get_frames([0,1]).shape == (2, 320, 560, 3)) + assert small_robot_mp4_vid.get_frames(0).shape == (1, 320, 560, 3) + assert small_robot_mp4_vid.get_frames([0, 1]).shape == (2, 320, 560, 3) def test_mp4_get_item(small_robot_mp4_vid): - assert(small_robot_mp4_vid[0].shape == (1, 320, 560, 3)) - assert(np.alltrue(small_robot_mp4_vid[1:10:3] == small_robot_mp4_vid.get_frames([1, 4, 7]))) + assert small_robot_mp4_vid[0].shape == (1, 320, 560, 3) + assert np.alltrue( + small_robot_mp4_vid[1:10:3] == small_robot_mp4_vid.get_frames([1, 4, 7]) + ) + def test_mp4_file_not_found(): with pytest.raises(FileNotFoundError): vid = Video.from_media("non-existent-filename.mp4") vid.channels + def test_numpy_frames(small_robot_mp4_vid): - clip_frames = small_robot_mp4_vid.get_frames((3,7,9)) + clip_frames = small_robot_mp4_vid.get_frames((3, 7, 9)) np_vid = Video.from_numpy(clip_frames) assert np.all(np.equal(np_vid.get_frame(1), small_robot_mp4_vid.get_frame(7))) -@pytest.mark.parametrize("format", ['png', 'jpg', "mjpeg/avi"]) + +@pytest.mark.parametrize("format", ["png", "jpg", "mjpeg/avi"]) def test_imgstore_video(small_robot_mp4_vid, tmpdir, format): - path = os.path.join(tmpdir, 'test_imgstore') + path = os.path.join(tmpdir, "test_imgstore") # If format is video, test saving all the frames. if format == "mjpeg/avi": - frame_indices = None + frame_indices = None else: frame_indices = [0, 1, 5] @@ -91,9 +108,13 @@ def test_imgstore_video(small_robot_mp4_vid, tmpdir, format): # video. if format == "png": # Check that the default format is "png" - imgstore_vid = small_robot_mp4_vid.to_imgstore(path, frame_numbers=frame_indices) + imgstore_vid = small_robot_mp4_vid.to_imgstore( + path, frame_numbers=frame_indices + ) else: - imgstore_vid = small_robot_mp4_vid.to_imgstore(path, frame_numbers=frame_indices, format=format) + imgstore_vid = small_robot_mp4_vid.to_imgstore( + path, frame_numbers=frame_indices, format=format + ) if frame_indices is None: assert small_robot_mp4_vid.num_frames == imgstore_vid.num_frames @@ -103,19 +124,21 @@ def test_imgstore_video(small_robot_mp4_vid, tmpdir, format): assert type(imgstore_vid.get_frame(i)) == np.ndarray else: - assert(imgstore_vid.num_frames == len(frame_indices)) + assert imgstore_vid.num_frames == len(frame_indices) # Make sure we can read arbitrary frames by imgstore frame number for i in frame_indices: assert type(imgstore_vid.get_frame(i)) == np.ndarray - assert(imgstore_vid.channels == 3) - assert(imgstore_vid.height == 320) - assert(imgstore_vid.width == 560) + assert imgstore_vid.channels == 3 + assert imgstore_vid.height == 320 + assert imgstore_vid.width == 560 # Check the image data is exactly the same when lossless is used. if format == "png": - assert np.allclose(imgstore_vid.get_frame(0), small_robot_mp4_vid.get_frame(0), rtol=0.91) + assert np.allclose( + imgstore_vid.get_frame(0), small_robot_mp4_vid.get_frame(0), rtol=0.91 + ) def test_imgstore_indexing(small_robot_mp4_vid, tmpdir): @@ -123,16 +146,20 @@ def test_imgstore_indexing(small_robot_mp4_vid, tmpdir): Test different types of indexing (by frame number or index) supported by only imgstore videos. """ - path = os.path.join(tmpdir, 'test_imgstore') + path = os.path.join(tmpdir, "test_imgstore") frame_indices = [20, 40, 15] - imgstore_vid = small_robot_mp4_vid.to_imgstore(path, frame_numbers=frame_indices, index_by_original=False) + imgstore_vid = small_robot_mp4_vid.to_imgstore( + path, frame_numbers=frame_indices, index_by_original=False + ) # Index by frame index in imgstore frames = imgstore_vid.get_frames([0, 1, 2]) assert frames.shape == (3, 320, 560, 3) + assert imgstore_vid.last_frame_idx == len(frame_indices) - 1 + with pytest.raises(ValueError): imgstore_vid.get_frames(frame_indices) @@ -143,5 +170,97 @@ def test_imgstore_indexing(small_robot_mp4_vid, tmpdir): frames = imgstore_vid.get_frames(frame_indices) assert frames.shape == (3, 320, 560, 3) + assert imgstore_vid.last_frame_idx == max(frame_indices) + with pytest.raises(ValueError): imgstore_vid.get_frames([0, 1, 2]) + + +def test_imgstore_deferred_loading(small_robot_mp4_vid, tmpdir): + path = os.path.join(tmpdir, "test_imgstore") + frame_indices = [20, 40, 15] + vid = small_robot_mp4_vid.to_imgstore(path, frame_numbers=frame_indices) + + # This is actually testing that the __img will be loaded when needed, + # since we use __img to get dtype. + assert vid.dtype == np.dtype("uint8") + + +def test_imgstore_single_channel(centered_pair_vid, tmpdir): + path = os.path.join(tmpdir, "test_imgstore") + frame_indices = [20, 40, 15] + vid = centered_pair_vid.to_imgstore(path, frame_numbers=frame_indices) + + assert vid.channels == 1 + + +def test_empty_hdf5_video(small_robot_mp4_vid, tmpdir): + path = os.path.join(tmpdir, "test_to_hdf5") + hdf5_vid = small_robot_mp4_vid.to_hdf5(path, "testvid", frame_numbers=[]) + + +@pytest.mark.parametrize("format", ["", "png", "jpg"]) +def test_hdf5_inline_video(small_robot_mp4_vid, tmpdir, format): + + path = os.path.join(tmpdir, f"test_to_hdf5_{format}") + frame_indices = [0, 1, 5] + + # Save hdf5 version of the first few frames of this video. + hdf5_vid = small_robot_mp4_vid.to_hdf5( + path, "testvid", format=format, frame_numbers=frame_indices + ) + + assert hdf5_vid.num_frames == len(frame_indices) + + # Make sure we can read arbitrary frames by imgstore frame number + for i in frame_indices: + assert type(hdf5_vid.get_frame(i)) == np.ndarray + + assert hdf5_vid.channels == 3 + assert hdf5_vid.height == 320 + assert hdf5_vid.width == 560 + + # Check the image data is exactly the same when lossless is used. + if format in ("", "png"): + assert np.allclose( + hdf5_vid.get_frame(0), small_robot_mp4_vid.get_frame(0), rtol=0.91 + ) + + +def test_hdf5_indexing(small_robot_mp4_vid, tmpdir): + """ + Test different types of indexing (by frame number or index). + """ + path = os.path.join(tmpdir, "test_to_hdf5") + + frame_indices = [20, 40, 15] + + hdf5_vid = small_robot_mp4_vid.to_hdf5( + path, dataset="testvid2", frame_numbers=frame_indices, index_by_original=False + ) + + # Index by frame index in imgstore + frames = hdf5_vid.get_frames([0, 1, 2]) + assert frames.shape == (3, 320, 560, 3) + + assert hdf5_vid.last_frame_idx == len(frame_indices) - 1 + + with pytest.raises(ValueError): + hdf5_vid.get_frames(frame_indices) + + # We have to close file before we can add another video dataset. + hdf5_vid.close() + + # Now re-create the imgstore with frame number indexing, (the default) + hdf5_vid2 = small_robot_mp4_vid.to_hdf5( + path, dataset="testvid3", frame_numbers=frame_indices + ) + + # Index by frame index in imgstore + frames = hdf5_vid2.get_frames(frame_indices) + assert frames.shape == (3, 320, 560, 3) + + assert hdf5_vid2.last_frame_idx == max(frame_indices) + + with pytest.raises(ValueError): + hdf5_vid2.get_frames([0, 1, 2]) diff --git a/tests/io/test_visuals.py b/tests/io/test_visuals.py new file mode 100644 index 000000000..9887a38c4 --- /dev/null +++ b/tests/io/test_visuals.py @@ -0,0 +1,14 @@ +import os +from sleap.io.visuals import save_labeled_video + + +def test_write_visuals(tmpdir, centered_pair_predictions): + path = os.path.join(tmpdir, "clip.avi") + save_labeled_video( + filename=path, + labels=centered_pair_predictions, + video=centered_pair_predictions.videos[0], + frames=(0, 1, 2), + fps=15, + ) + assert os.path.exists(path) diff --git a/tests/nn/test_datagen.py b/tests/nn/test_datagen.py index 97d0d5edd..a671d1a26 100644 --- a/tests/nn/test_datagen.py +++ b/tests/nn/test_datagen.py @@ -1,5 +1,6 @@ from sleap.nn.datagen import generate_images, generate_confidence_maps, generate_pafs + def test_datagen(min_labels): import numpy as np @@ -9,15 +10,14 @@ def test_datagen(min_labels): assert imgs.shape == (1, 384, 384, 1) assert imgs.dtype == np.dtype("float32") - assert math.isclose(np.ptp(imgs), .898, abs_tol=.01) + assert math.isclose(np.ptp(imgs), 0.898, abs_tol=0.01) confmaps = generate_confidence_maps(min_labels) assert confmaps.shape == (1, 384, 384, 2) assert confmaps.dtype == np.dtype("float32") - assert math.isclose(np.ptp(confmaps), .999, abs_tol=.01) - + assert math.isclose(np.ptp(confmaps), 0.999, abs_tol=0.01) pafs = generate_pafs(min_labels) assert pafs.shape == (1, 384, 384, 2) assert pafs.dtype == np.dtype("float32") - assert math.isclose(np.ptp(pafs), 1.57, abs_tol=.01) \ No newline at end of file + assert math.isclose(np.ptp(pafs), 1.57, abs_tol=0.01) diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index 21f5a2f76..2641b6c8a 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -12,6 +12,7 @@ from sleap.io.dataset import Labels + def check_labels(labels): # Make sure there are 1100 frames @@ -19,7 +20,7 @@ def check_labels(labels): for i in labels.all_instances: assert type(i) == PredictedInstance - assert type(i.points()[0]) == PredictedPoint + assert type(i.points[0]) == PredictedPoint # Make sure frames are in order for i, frame in enumerate(labels): @@ -33,15 +34,18 @@ def check_labels(labels): # FIXME: We need more checks here. + def test_load_old_json(): - labels = load_predicted_labels_json_old("tests/data/json_format_v1/centered_pair.json") + old_json_filename = "tests/data/json_format_v1/centered_pair.json" + labels = Labels(load_predicted_labels_json_old(old_json_filename)) check_labels(labels) - #Labels.save_json(labels, 'tests/data/json_format_v2/centered_pair_predictions.json') + # Labels.save_json(labels, 'tests/data/json_format_v2/centered_pair_predictions.json') + def test_save_load_json(centered_pair_predictions, tmpdir): - test_out_file = os.path.join(tmpdir, 'test_tmp.json') + test_out_file = os.path.join(tmpdir, "test_tmp.json") # Check the labels check_labels(centered_pair_predictions) @@ -54,20 +58,21 @@ def test_save_load_json(centered_pair_predictions, tmpdir): check_labels(new_labels) + def test_peaks_with_scaling(): # load from scratch so we won't change centered_pair_predictions - true_labels = Labels.load_json('tests/data/json_format_v1/centered_pair.json') + true_labels = Labels.load_json("tests/data/json_format_v1/centered_pair.json") # only use a few frames true_labels.labeled_frames = true_labels.labeled_frames[13:23:2] skeleton = true_labels.skeletons[0] imgs = generate_images(true_labels) # scaling - scale = .5 + scale = 0.5 transform = DataTransform() img_size = imgs.shape[1], imgs.shape[2] - scaled_size = int(imgs.shape[1]//(1/scale)), int(imgs.shape[2]//(1/scale)) + scaled_size = int(imgs.shape[1] // (1 / scale)), int(imgs.shape[2] // (1 / scale)) imgs = transform.scale_to(imgs, scaled_size) assert transform.scale == scale assert imgs.shape[1], imgs.shape[2] == scaled_size @@ -83,15 +88,24 @@ def test_peaks_with_scaling(): # make sure what we got from interence matches what we started with for i in range(len(new_labels.labeled_frames)): - assert len(true_labels.labeled_frames[i].instances) <= len(new_labels.labeled_frames[i].instances) + assert len(true_labels.labeled_frames[i].instances) <= len( + new_labels.labeled_frames[i].instances + ) # sort instances by location of thorax true_labels.labeled_frames[i].instances.sort(key=lambda inst: inst["thorax"]) new_labels.labeled_frames[i].instances.sort(key=lambda inst: inst["thorax"]) # make sure that each true instance has points matching one of the new instances - for inst_a, inst_b in zip(true_labels.labeled_frames[i].instances, new_labels.labeled_frames[i].instances): - - assert inst_a.points_array().shape == inst_b.points_array().shape + for inst_a, inst_b in zip( + true_labels.labeled_frames[i].instances, + new_labels.labeled_frames[i].instances, + ): + + assert inst_a.get_points_array().shape == inst_b.get_points_array().shape # FIXME: new instances have nans, so for now just check first 5 points - assert np.allclose(inst_a.points_array()[0:5], inst_b.points_array()[0:5], atol=1/scale) + assert np.allclose( + inst_a.get_points_array()[0:5], + inst_b.get_points_array()[0:5], + atol=1 / scale, + ) diff --git a/tests/nn/test_tracking.py b/tests/nn/test_tracking.py index 7e56b1606..6f794f3e2 100644 --- a/tests/nn/test_tracking.py +++ b/tests/nn/test_tracking.py @@ -1,6 +1,7 @@ from sleap.nn.tracking import FlowShiftTracker from sleap.io.dataset import Labels + def test_flow_tracker(centered_pair_vid, centered_pair_predictions): # We are going to test tracking. The dataset we have loaded @@ -24,6 +25,7 @@ def test_flow_tracker(centered_pair_vid, centered_pair_predictions): # shouldn't be hard coded. assert len(tracks) == 24 + # def test_tracking_optflow_fail(centered_pair_vid, centered_pair_predictions): # frame_nums = range(0, len(centered_pair_predictions), 2) # labels = Labels([centered_pair_predictions[i] for i in frame_nums]) diff --git a/tests/nn/test_training.py b/tests/nn/test_training.py index e583538c2..24f4d8e5a 100644 --- a/tests/nn/test_training.py +++ b/tests/nn/test_training.py @@ -7,21 +7,29 @@ from sleap.nn.architectures.leap import leap_cnn from sleap.nn.training import Trainer, TrainingJob + def test_model_fail_non_available_backbone(multi_skel_vid_labels): with pytest.raises(ValueError): - Model(output_type=ModelOutputType.CONFIDENCE_MAP, backbone=object(), - skeletons=multi_skel_vid_labels.skeletons) + Model( + output_type=ModelOutputType.CONFIDENCE_MAP, + backbone=object(), + skeletons=multi_skel_vid_labels.skeletons, + ) @pytest.mark.parametrize("backbone", available_archs) def test_training_job_json(tmpdir, multi_skel_vid_labels, backbone): - run_name = 'training' + run_name = "training" - model = Model(output_type=ModelOutputType.CONFIDENCE_MAP, backbone=backbone(), - skeletons=multi_skel_vid_labels.skeletons) + model = Model( + output_type=ModelOutputType.CONFIDENCE_MAP, + backbone=backbone(), + skeletons=multi_skel_vid_labels.skeletons, + ) - train_run = TrainingJob(model=model, trainer=Trainer(), - save_dir=os.path.join(tmpdir), run_name=run_name) + train_run = TrainingJob( + model=model, trainer=Trainer(), save_dir=os.path.join(tmpdir), run_name=run_name + ) # Create and serialize training info json_path = os.path.join(tmpdir, f"{run_name}.json") @@ -30,10 +38,12 @@ def test_training_job_json(tmpdir, multi_skel_vid_labels, backbone): # Load the JSON back in loaded_run = TrainingJob.load_json(json_path) - assert loaded_run == train_run - - # Make sure the skeletons match too, not sure what the difference - # between __eq__ and matches on skeleton is at this point. + # Make sure the skeletons match (even though not eq) for sk1, sk2 in zip(loaded_run.model.skeletons, train_run.model.skeletons): assert sk1.matches(sk2) + # Now remove the skeletons since we want to check eq on everything else + loaded_run.model.skeletons = [] + train_run.model.skeletons = [] + + assert loaded_run == train_run diff --git a/tests/test_instance.py b/tests/test_instance.py index dbe402cad..d368ae477 100644 --- a/tests/test_instance.py +++ b/tests/test_instance.py @@ -8,6 +8,7 @@ from sleap.skeleton import Skeleton from sleap.instance import Instance, Point, LabeledFrame + def test_instance_node_get_set_item(skeleton): """ Test basic get item and set item functionality of instances. @@ -32,7 +33,7 @@ def test_instance_node_multi_get_set_item(skeleton): Test basic get item and set item functionality of instances. """ node_names = ["left-wing", "head", "right-wing"] - points = {"head": Point(1, 4), "left-wing": Point(2, 5), "right-wing": Point(3,6)} + points = {"head": Point(1, 4), "left-wing": Point(2, 5), "right-wing": Point(3, 6)} instance1 = Instance(skeleton=skeleton, points=points) @@ -55,7 +56,7 @@ def test_non_exist_node(skeleton): instance["non-existent-node"].x = 1 with pytest.raises(KeyError): - instance = Instance(skeleton=skeleton, points = {"non-exist": Point()}) + instance = Instance(skeleton=skeleton, points={"non-exist": Point()}) def test_instance_point_iter(skeleton): @@ -67,9 +68,9 @@ def test_instance_point_iter(skeleton): instance = Instance(skeleton=skeleton, points=points) - assert [node.name for node in instance.nodes] == ['head', 'left-wing', 'right-wing'] - assert np.allclose([p.x for p in instance.points()], [1, 2, 3]) - assert np.allclose([p.y for p in instance.points()], [4, 5, 6]) + assert [node.name for node in instance.nodes] == ["head", "left-wing", "right-wing"] + assert np.allclose([p.x for p in instance.points], [1, 2, 3]) + assert np.allclose([p.y for p in instance.points], [4, 5, 6]) # Make sure we can iterate over tuples for (node, point) in instance.nodes_points: @@ -83,28 +84,29 @@ def test_skeleton_node_name_change(): """ s = Skeleton("Test") - s.add_nodes(['a', 'b', 'c', 'd', 'e']) - s.add_edge('a', 'b') + s.add_nodes(["a", "b", "c", "d", "e"]) + s.add_edge("a", "b") instance = Instance(s) - instance['a'] = Point(1,2) - instance['b'] = Point(3,4) + instance["a"] = Point(1, 2) + instance["b"] = Point(3, 4) # Rename the node - s.relabel_nodes({'a': 'A'}) + s.relabel_nodes({"a": "A"}) # Reference to the old node name should raise a KeyError with pytest.raises(KeyError): - instance['a'].x = 2 + instance["a"].x = 2 # Make sure the A now references the same point on the instance - assert instance['A'] == Point(1, 2) - assert instance['b'] == Point(3, 4) + assert instance["A"] == Point(1, 2) + assert instance["b"] == Point(3, 4) + def test_instance_comparison(skeleton): node_names = ["left-wing", "head", "right-wing"] - points = {"head": Point(1, 4), "left-wing": Point(2, 5), "right-wing": Point(3,6)} + points = {"head": Point(1, 4), "left-wing": Point(2, 5), "right-wing": Point(3, 6)} instance1 = Instance(skeleton=skeleton, points=points) instance2 = copy.deepcopy(instance1) @@ -119,9 +121,10 @@ def test_instance_comparison(skeleton): assert not instance1.matches(instance2) instance2 = copy.deepcopy(instance1) - instance2.skeleton.add_node('extra_node') + instance2.skeleton.add_node("extra_node") assert not instance1.matches(instance2) + def test_points_array(skeleton): """ Test conversion of instances to points array""" @@ -130,29 +133,49 @@ def test_points_array(skeleton): instance1 = Instance(skeleton=skeleton, points=points) - pts = instance1.points_array() + pts = instance1.get_points_array() assert pts.shape == (len(skeleton.nodes), 2) - assert np.allclose(pts[skeleton.node_to_index('left-wing'), :], [2, 5]) - assert np.allclose(pts[skeleton.node_to_index('head'), :], [1, 4]) - assert np.allclose(pts[skeleton.node_to_index('right-wing'), :], [3, 6]) - assert np.isnan(pts[skeleton.node_to_index('thorax'), :]).all() + assert np.allclose(pts[skeleton.node_to_index("left-wing"), :], [2, 5]) + assert np.allclose(pts[skeleton.node_to_index("head"), :], [1, 4]) + assert np.allclose(pts[skeleton.node_to_index("right-wing"), :], [3, 6]) + assert np.isnan(pts[skeleton.node_to_index("thorax"), :]).all() # Now change a point, make sure it is reflected - instance1['head'].x = 0 - instance1['thorax'] = Point(1, 2) - pts = instance1.points_array() - assert np.allclose(pts[skeleton.node_to_index('head'), :], [0, 4]) - assert np.allclose(pts[skeleton.node_to_index('thorax'), :], [1, 2]) + instance1["head"].x = 0 + instance1["thorax"] = Point(1, 2) + pts = instance1.get_points_array() + assert np.allclose(pts[skeleton.node_to_index("head"), :], [0, 4]) + assert np.allclose(pts[skeleton.node_to_index("thorax"), :], [1, 2]) # Make sure that invisible points are nan iff invisible_as_nan=True - instance1['thorax'] = Point(1, 2, visible=False) + instance1["thorax"] = Point(1, 2, visible=False) + + pts = instance1.get_points_array() + assert not np.isnan(pts[skeleton.node_to_index("thorax"), :]).all() + + pts = instance1.points_array + assert np.isnan(pts[skeleton.node_to_index("thorax"), :]).all() + + +def test_modifying_skeleton(skeleton): + node_names = ["left-wing", "head", "right-wing"] + points = {"head": Point(1, 4), "left-wing": Point(2, 5), "right-wing": Point(3, 6)} + + instance1 = Instance(skeleton=skeleton, points=points) + + assert len(instance1.points) == 3 - pts = instance1.points_array() - assert not np.isnan(pts[skeleton.node_to_index('thorax'), :]).all() + skeleton.add_node("new test node") + + instance1.points # this updates instance with changes from skeleton + instance1["new test node"] = Point(7, 8) + + assert len(instance1.points) == 4 + + skeleton.delete_node("head") + assert len(instance1.points) == 3 - pts = instance1.points_array(invisible_as_nan=True) - assert np.isnan(pts[skeleton.node_to_index('thorax'), :]).all() def test_instance_labeled_frame_ref(skeleton, centered_pair_vid): """ @@ -165,4 +188,3 @@ def test_instance_labeled_frame_ref(skeleton, centered_pair_vid): assert frame.instances[0].frame == frame assert frame[0].frame == frame assert frame[0].frame_idx == 0 - diff --git a/tests/test_point_array.py b/tests/test_point_array.py index 49d452064..9410b52b5 100644 --- a/tests/test_point_array.py +++ b/tests/test_point_array.py @@ -3,8 +3,16 @@ from sleap.instance import Point, PredictedPoint, PointArray, PredictedPointArray -@pytest.mark.parametrize("p1", [Point(0.0, 0.0), PredictedPoint(0.0, 0.0, 0.0), - PointArray(3)[0], PredictedPointArray(3)[0]]) + +@pytest.mark.parametrize( + "p1", + [ + Point(0.0, 0.0), + PredictedPoint(0.0, 0.0, 0.0), + PointArray(3)[0], + PredictedPointArray(3)[0], + ], +) def test_point(p1): """ Test the Point and PredictedPoint API. This is mainly a safety @@ -40,15 +48,15 @@ def test_constructor(): assert p.score == 0.3 -@pytest.mark.parametrize('parray_cls', [PointArray, PredictedPointArray]) +@pytest.mark.parametrize("parray_cls", [PointArray, PredictedPointArray]) def test_point_array(parray_cls): p = parray_cls(5) # Make sure length works assert len(p) == 5 - assert len(p['x']) == 5 - assert len(p[['x', 'y']]) == 5 + assert len(p["x"]) == 5 + assert len(p[["x", "y"]]) == 5 # Check that single point getitem returns a Point class if parray_cls is PredictedPointArray: @@ -69,7 +77,10 @@ def test_point_array(parray_cls): # I have to convert from structured to unstructured to get this comparison # to work. from numpy.lib.recfunctions import structured_to_unstructured - np.testing.assert_array_equal(structured_to_unstructured(d1), structured_to_unstructured(d2)) + + np.testing.assert_array_equal( + structured_to_unstructured(d1), structured_to_unstructured(d2) + ) def test_from_and_to_array(): @@ -79,9 +90,11 @@ def test_from_and_to_array(): r = PredictedPointArray.to_array(PredictedPointArray.from_array(p)) from numpy.lib.recfunctions import structured_to_unstructured - np.testing.assert_array_equal(structured_to_unstructured(p), structured_to_unstructured(r)) + + np.testing.assert_array_equal( + structured_to_unstructured(p), structured_to_unstructured(r) + ) # Make sure conversion uses default score r = PredictedPointArray.from_array(p) assert r.score[0] == PredictedPointArray.make_default(1)[0].score - diff --git a/tests/test_rangelist.py b/tests/test_rangelist.py index 3f813362c..579b0d9a7 100644 --- a/tests/test_rangelist.py +++ b/tests/test_rangelist.py @@ -1,17 +1,30 @@ from sleap.rangelist import RangeList + def test_rangelist(): - a = RangeList([(1,2),(3,5),(7,13),(50,100)]) + a = RangeList([(1, 2), (3, 5), (7, 13), (50, 100)]) assert a.list == [(1, 2), (3, 5), (7, 13), (50, 100)] assert a.cut(8) == ([(1, 2), (3, 5), (7, 8)], [(8, 13), (50, 100)]) - assert a.cut_range((60,70)) == ([(1, 2), (3, 5), (7, 13), (50, 60)], [(60, 70)], [(70, 100)]) - assert a.insert((10,20)) == [(1, 2), (3, 5), (7, 20), (50, 100)] - assert a.insert((5,8)) == [(1, 2), (3, 20), (50, 100)] + assert a.cut_range((60, 70)) == ( + [(1, 2), (3, 5), (7, 13), (50, 60)], + [(60, 70)], + [(70, 100)], + ) + + # Test inserting range as tuple + assert a.insert((10, 20)) == [(1, 2), (3, 5), (7, 20), (50, 100)] + + # Test insert range as range + assert a.insert(range(5, 8)) == [(1, 2), (3, 20), (50, 100)] - a.remove((5,8)) + a.remove((5, 8)) assert a.list == [(1, 2), (3, 5), (8, 20), (50, 100)] - + + assert a.start == 1 + a.remove((1, 3)) + assert a.start == 3 + b = RangeList() b.add(1) b.add(2) @@ -21,4 +34,20 @@ def test_rangelist(): b.add(9) b.add(10) - assert b.list == [(1, 3), (4, 7), (9, 11)] \ No newline at end of file + assert b.list == [(1, 3), (4, 7), (9, 11)] + + empty = RangeList() + assert empty.start is None + assert empty.cut_range((3, 4)) == ([], [], []) + + empty.insert((1, 2)) + assert str(empty) == "RangeList([(1, 2)])" + + empty.insert_list([(1, 2), (3, 5), (7, 13), (50, 100)]) + assert empty.list == [(1, 2), (3, 5), (7, 13), (50, 100)] + + # Test special cases for helper functions + assert RangeList.join_([(1, 2)]) == (1, 2) + assert RangeList.join_pair_(list_a=[(1, 2)], list_b=[]) == [(1, 2)] + assert RangeList.join_pair_(list_a=[], list_b=[(1, 2)]) == [(1, 2)] + assert RangeList.join_pair_(list_a=[], list_b=[]) == [] diff --git a/tests/test_skeleton.py b/tests/test_skeleton.py index 1f2080836..ebb88721b 100644 --- a/tests/test_skeleton.py +++ b/tests/test_skeleton.py @@ -64,7 +64,8 @@ def test_getitem_node(skeleton): skeleton["non_exist_node"] # Now try to get the head node - assert(skeleton["head"] is not None) + assert skeleton["head"] is not None + def test_contains_node(skeleton): """ @@ -86,17 +87,16 @@ def test_node_rename(skeleton): skeleton["head"] # Make sure new head has the correct name - assert(skeleton["new_head_name"] is not None) + assert skeleton["new_head_name"] is not None def test_eq(): s1 = Skeleton("s1") - s1.add_nodes(['1','2','3','4','5','6']) - s1.add_edge('1', '2') - s1.add_edge('3', '4') - s1.add_edge('5', '6') - s1.add_symmetry('3', '6') - + s1.add_nodes(["1", "2", "3", "4", "5", "6"]) + s1.add_edge("1", "2") + s1.add_edge("3", "4") + s1.add_edge("5", "6") + s1.add_symmetry("3", "6") # Make a copy check that they are equal s2 = copy.deepcopy(s1) @@ -104,22 +104,22 @@ def test_eq(): # Add an edge, check that they are not equal s2 = copy.deepcopy(s1) - s2.add_edge('5', '1') + s2.add_edge("5", "1") assert not s1.matches(s2) # Add a symmetry edge, not equal s2 = copy.deepcopy(s1) - s2.add_symmetry('5', '1') + s2.add_symmetry("5", "1") assert not s1.matches(s2) # Delete a node s2 = copy.deepcopy(s1) - s2.delete_node('5') + s2.delete_node("5") assert not s1.matches(s2) # Delete and edge, not equal s2 = copy.deepcopy(s1) - s2.delete_edge('1', '2') + s2.delete_edge("1", "2") assert not s1.matches(s2) # FIXME: Probably shouldn't test it this way. @@ -133,14 +133,15 @@ def test_eq(): # s2._graph.nodes['1']['test'] = 5 # assert s1 != s2 + def test_symmetry(): s1 = Skeleton("s1") - s1.add_nodes(['1','2','3','4','5','6']) - s1.add_edge('1', '2') - s1.add_edge('3', '4') - s1.add_edge('5', '6') - s1.add_symmetry('1', '5') - s1.add_symmetry('3', '6') + s1.add_nodes(["1", "2", "3", "4", "5", "6"]) + s1.add_edge("1", "2") + s1.add_edge("3", "4") + s1.add_edge("5", "6") + s1.add_symmetry("1", "5") + s1.add_symmetry("3", "6") assert s1.get_symmetry("1").name == "5" assert s1.get_symmetry("5").name == "1" @@ -149,15 +150,22 @@ def test_symmetry(): # Cannot add more than one symmetry to a node with pytest.raises(ValueError): - s1.add_symmetry('1', '6') + s1.add_symmetry("1", "6") + with pytest.raises(ValueError): + s1.add_symmetry("6", "1") + + s1.delete_symmetry("1", "5") + assert s1.get_symmetry("1") is None + with pytest.raises(ValueError): - s1.add_symmetry('6', '1') + s1.delete_symmetry("1", "5") + def test_json(skeleton, tmpdir): """ Test saving and loading a Skeleton object in JSON. """ - JSON_TEST_FILENAME = os.path.join(tmpdir, 'skeleton.json') + JSON_TEST_FILENAME = os.path.join(tmpdir, "skeleton.json") # Save it to a JSON filename skeleton.save_json(JSON_TEST_FILENAME) @@ -166,11 +174,11 @@ def test_json(skeleton, tmpdir): skeleton_copy = Skeleton.load_json(JSON_TEST_FILENAME) # Make sure we get back the same skeleton we saved. - assert(skeleton.matches(skeleton_copy)) + assert skeleton.matches(skeleton_copy) def test_hdf5(skeleton, stickman, tmpdir): - filename = os.path.join(tmpdir, 'skeleton.h5') + filename = os.path.join(tmpdir, "skeleton.h5") if os.path.isfile(filename): os.remove(filename) @@ -202,7 +210,7 @@ def test_hdf5(skeleton, stickman, tmpdir): # Make sure we can't load a non-existent skeleton with pytest.raises(KeyError): - Skeleton.load_hdf5(filename, 'BadName') + Skeleton.load_hdf5(filename, "BadName") # Make sure we can't save skeletons with the same name with pytest.raises(ValueError): @@ -233,73 +241,83 @@ def dict_match(dict1, dict2): with pytest.raises(NotImplementedError): skeleton.name = "Test" + def test_graph_property(skeleton): assert [node for node in skeleton.graph.nodes()] == skeleton.nodes + def test_load_mat_format(): - skeleton = Skeleton.load_mat('tests/data/skeleton/leap_mat_format/skeleton_legs.mat') + skeleton = Skeleton.load_mat( + "tests/data/skeleton/leap_mat_format/skeleton_legs.mat" + ) # Check some stuff about the skeleton we loaded - assert(len(skeleton.nodes) == 24) - assert(len(skeleton.edges) == 23) + assert len(skeleton.nodes) == 24 + assert len(skeleton.edges) == 23 # The node and edge list that should be present in skeleton_legs.mat node_names = [ - 'head', - 'neck', - 'thorax', - 'abdomen', - 'wingL', - 'wingR', - 'forelegL1', - 'forelegL2', - 'forelegL3', - 'forelegR1', - 'forelegR2', - 'forelegR3', - 'midlegL1' , - 'midlegL2' , - 'midlegL3' , - 'midlegR1' , - 'midlegR2' , - 'midlegR3' , - 'hindlegL1', - 'hindlegL2', - 'hindlegL3', - 'hindlegR1', - 'hindlegR2', - 'hindlegR3'] + "head", + "neck", + "thorax", + "abdomen", + "wingL", + "wingR", + "forelegL1", + "forelegL2", + "forelegL3", + "forelegR1", + "forelegR2", + "forelegR3", + "midlegL1", + "midlegL2", + "midlegL3", + "midlegR1", + "midlegR2", + "midlegR3", + "hindlegL1", + "hindlegL2", + "hindlegL3", + "hindlegR1", + "hindlegR2", + "hindlegR3", + ] edges = [ - [ 2, 1], - [ 1, 0], - [ 2, 3], - [ 2, 4], - [ 2, 5], - [ 2, 6], - [ 6, 7], - [ 7, 8], - [ 2, 9], - [ 9, 10], - [10, 11], - [ 2, 12], - [12, 13], - [13, 14], - [ 2, 15], - [15, 16], - [16, 17], - [ 2, 18], - [18, 19], - [19, 20], - [ 2, 21], - [21, 22], - [22, 23]] + [2, 1], + [1, 0], + [2, 3], + [2, 4], + [2, 5], + [2, 6], + [6, 7], + [7, 8], + [2, 9], + [9, 10], + [10, 11], + [2, 12], + [12, 13], + [13, 14], + [2, 15], + [15, 16], + [16, 17], + [2, 18], + [18, 19], + [19, 20], + [2, 21], + [21, 22], + [22, 23], + ] assert [n.name for n in skeleton.nodes] == node_names # Check the edges and their order for i, edge in enumerate(skeleton.edge_names): - assert tuple(edges[i]) == (skeleton.node_to_index(edge[0]), skeleton.node_to_index(edge[1])) + assert tuple(edges[i]) == ( + skeleton.node_to_index(edge[0]), + skeleton.node_to_index(edge[1]), + ) + def test_edge_order(): """Test is edge list order is maintained upon insertion""" diff --git a/tests/test_util.py b/tests/test_util.py index 50ce8da83..e72a78bc2 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -4,13 +4,26 @@ from typing import List, Dict -from sleap.util import attr_to_dtype +from sleap.util import ( + json_dumps, + json_loads, + attr_to_dtype, + frame_list, + weak_filename_match, +) + + +def test_json(): + original_dict = dict(key=123) + assert original_dict == json_loads(json_dumps(original_dict)) + def test_attr_to_dtype(): """ Test that we can convert classes with basic types to numpy composite dtypes. """ + @attr.s class TestAttr: a: int = attr.ib() @@ -30,10 +43,10 @@ class TestAttr3: c: Dict = attr.ib() # Dict should throw exception! dtype = attr_to_dtype(TestAttr) - dtype.fields['a'][0] == np.dtype(int) - dtype.fields['b'][0] == np.dtype(float) - dtype.fields['c'][0] == np.dtype(bool) - dtype.fields['d'][0] == np.dtype(object) + dtype.fields["a"][0] == np.dtype(int) + dtype.fields["b"][0] == np.dtype(float) + dtype.fields["c"][0] == np.dtype(bool) + dtype.fields["d"][0] == np.dtype(object) with pytest.raises(TypeError): attr_to_dtype(TestAttr2) @@ -41,3 +54,20 @@ class TestAttr3: with pytest.raises(TypeError): attr_to_dtype(TestAttr3) + +def test_frame_list(): + assert frame_list("3-5") == [3, 4, 5] + assert frame_list("7,10") == [7, 10] + + +def test_weak_match(): + assert weak_filename_match("one/two", "one/two") + assert weak_filename_match( + "M:\\code\\sandbox\\sleap_nas\\pilot_6pts\\tmp_11576_FoxP1_6pts.training.n=468.json.zip\\frame_data_vid0\\metadata.yaml", + "D:\\projects\\code\\sandbox\\sleap_nas\\pilot_6pts\\tmp_99713_FoxP1_6pts.training.n=468.json.zip\\frame_data_vid0\\metadata.yaml", + ) + assert weak_filename_match("zero/one/two/three.mp4", "other\\one\\two\\three.mp4") + + assert not weak_filename_match("one/two/three", "two/three") + assert not weak_filename_match("one/two/three.mp4", "one/two/three.avi") + assert not weak_filename_match("foo.mp4", "bar.mp4")