Skip to content

Commit

Permalink
Merge branch 'develop'
Browse files Browse the repository at this point in the history
  • Loading branch information
ntabris committed Jul 15, 2020
2 parents a370a2f + c0f57aa commit 59786f0
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 14 deletions.
43 changes: 33 additions & 10 deletions sleap/gui/dialogs/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import pandas as pd
import seaborn as sns

from sleap import Labels, Skeleton
from sleap.gui.dataviews import GenericTableModel, GenericTableView
from sleap.gui.dialogs.filedialog import FileDialog
from sleap.gui.learning.configs import TrainingConfigsGetter, ConfigFileInfo
Expand All @@ -19,13 +18,12 @@ class MetricsTableDialog(QtWidgets.QWidget):
def __init__(self, labels_filename: Text):
super(MetricsTableDialog, self).__init__()

labels = Labels.load_file(labels_filename)
self.skeleton = labels.skeletons[0]
labels_filename = labels_filename or ""

self._cfg_getter = TrainingConfigsGetter.make_from_labels_filename(
labels_filename,
)
self._cfg_getter.search_depth = 2
self._cfg_getter.search_depth = 4

self.table_model = MetricsTableModel(items=[])
self.table_view = GenericTableView(
Expand Down Expand Up @@ -110,7 +108,7 @@ def _show_metric_details(

key = cfg_info.path
if key not in metric_detail_widgets:
metric_detail_widgets[key] = DetailedMetricsDialog(cfg_info, self.skeleton)
metric_detail_widgets[key] = DetailedMetricsDialog(cfg_info)

metric_detail_widgets[key].show()
metric_detail_widgets[key].raise_()
Expand Down Expand Up @@ -193,15 +191,37 @@ def item_to_data(self, obj, cfg: ConfigFileInfo):
return item_data


METRICS_KEY_LABELS = {
"vis.tp": "Visibility - True Positives",
"vis.fp": "Visibility - False Positives",
"vis.tn": "Visibility - True Negatives",
"vis.fn": "Visibility - False Negatives",
"vis.precision": "Visibility - Precision",
"vis.recall": "Visibility - Recall",
"dist.avg": "Average Distance (ground truth vs prediction)",
"dist.p50": "Distance for 50th percentile",
"dist.p75": "Distance for 75th percentile",
"dist.p90": "Distance for 90th percentile",
"dist.p95": "Distance for 95th percentile",
"dist.p99": "Distance for 99th percentile",
"pck.mPCK": "Mean Percentage of Correct Keypoints (PCK)",
"oks.mOKS": "Mean Object Keypoint Similarity (OKS)",
"oks_voc.mAP": "VOC with OKS scores - mean Average Precision (mAP)",
"oks_voc.mAR": "VOC with OKS scores - mean Average Recall (mAR)",
"pck_voc.mAP": "VOC with PCK scores - mean Average Precision (mAP)",
"pck_voc.mAR": "VOC with PCK scores - mean Average Recall (mAR)",
}


class DetailedMetricsDialog(QtWidgets.QWidget):
def __init__(self, cfg_info: ConfigFileInfo, skeleton: Skeleton):
def __init__(self, cfg_info: ConfigFileInfo):
super(DetailedMetricsDialog, self).__init__()

self.setWindowTitle(cfg_info.path_dir)
self.setMinimumWidth(800)

self.cfg_info = cfg_info
self.skeleton = skeleton
self.skeleton = cfg_info.skeleton

self.metrics = self.cfg_info.metrics

Expand All @@ -217,9 +237,13 @@ def __init__(self, cfg_info: ConfigFileInfo, skeleton: Skeleton):
):
val_str = str(val)

key_str = (
METRICS_KEY_LABELS[key] if key in METRICS_KEY_LABELS else key
)

text_widget = QtWidgets.QLabel(val_str)
text_widget.setTextInteractionFlags(QtCore.Qt.TextSelectableByMouse)
metrics_layout.addRow(f"<b>{key}</b>:", text_widget)
metrics_layout.addRow(f"<b>{key_str}</b>:", text_widget)

metrics_widget = QtWidgets.QWidget()
metrics_widget.setLayout(metrics_layout)
Expand All @@ -236,8 +260,7 @@ def __init__(self, cfg_info: ConfigFileInfo, skeleton: Skeleton):
def plot_distances(self):
ax = self.canvas.axes

# node_names = self.cfg_info.config.data.labels.skeletons[0].node_names
node_names = self.skeleton.node_names
node_names = self.skeleton.node_names if self.skeleton else None

dists = pd.DataFrame(self.metrics["dist.dists"], columns=node_names).melt(
var_name="Part", value_name="Error"
Expand Down
26 changes: 26 additions & 0 deletions sleap/gui/learning/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
import numpy as np

from sleap import Labels, Skeleton
from sleap import util as sleap_utils
from sleap.gui.dialogs.filedialog import FileDialog
from sleap.nn.config import TrainingJobConfig
Expand All @@ -22,6 +23,8 @@ class ConfigFileInfo:
filename: Optional[Text] = None
head_name: Optional[Text] = None
dont_retrain: bool = False
_skeleton: Optional[Skeleton] = None
_tried_finding_skeleton: bool = False
_dset_len_cache: dict = attr.ib(factory=dict)

@property
Expand Down Expand Up @@ -63,6 +66,28 @@ def _get_file_path(self, shortname) -> Optional[Text]:
def metrics(self):
return self._get_metrics("val")

@property
def skeleton(self):
# cache skeleton so we only search once
if self._skeleton is None and not self._tried_finding_skeleton:

# if skeleton was saved in config, great!
if self.config.data.labels.skeletons:
self._skeleton = self.config.data.labels.skeletons[0]

# otherwise try loading it from validation labels (much slower!)
else:
filename = self._get_file_path(f"labels_gt.val.slp")
if filename is not None:
val_labels = Labels.load_file(filename)
if val_labels.skeletons:
self._skeleton = val_labels.skeletons[0]

# don't try loading again (needed in case it's still None)
self._tried_finding_skeleton = True

return self._skeleton

@property
def training_instance_count(self):
return self._get_dataset_len("instances", "train")
Expand Down Expand Up @@ -310,6 +335,7 @@ def get_filtered_configs(
# if files were copied

cfg_dir = os.path.dirname(cfg_info.path)

if cfg_dir not in paths_included:
paths_included.append(cfg_dir)
cfgs_to_return.append(cfg_info)
Expand Down
12 changes: 9 additions & 3 deletions sleap/io/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,22 +988,28 @@ def get_frames_safely(self, idxs: Iterable[int]) -> Tuple[List[int], np.ndarray]
Returns: A tuple of (frame indices, frames), where
* frame indices is a subset of the specified idxs, and
* frames has shape (len(frame indices), height, width, channels).
If zero frames were loaded successfully, then frames is None.
"""
frames = []
idxs_found = []

for idx in idxs:
try:
frame = self.get_frame(idx)
except:
# quietly ignore frames which we couldn't load
except Exception as e:
print(e)
# ignore frames which we couldn't load
frame = None

if frame is not None:
frames.append(frame)
idxs_found.append(idx)

frames = np.stack(frames, axis=0)
if frames:
frames = np.stack(frames, axis=0)
else:
frames = None

return idxs_found, frames

def __getitem__(self, idxs):
Expand Down
5 changes: 5 additions & 0 deletions sleap/io/visuals.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ def reader(out_q: Queue, video: Video, frames: List[int], scale: float = 1.0):
frames_idx_chunk
)

if not loaded_chunk_idxs:
print(f"No frames could be loaded from chunk {chunk_i}")
i += 1
continue

if scale != 1.0:
video_frame_images = resize_images(video_frame_images, scale)

Expand Down
2 changes: 1 addition & 1 deletion sleap/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.0.7"
__version__ = "1.0.8"
9 changes: 9 additions & 0 deletions tests/io/test_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,3 +364,12 @@ def test_safe_frame_loading(small_robot_mp4_vid):

assert idxs == [1, 2]
assert len(frames) == 2


def test_safe_frame_loading_all_invalid():
vid = Video.from_filename("video_that_does_not_exist.mp4")

idxs, frames = vid.get_frames_safely(list(range(10)))

assert idxs == []
assert frames is None

0 comments on commit 59786f0

Please sign in to comment.