diff --git a/sleap/gui/app.py b/sleap/gui/app.py
index 6041d9f37..61a24c60f 100644
--- a/sleap/gui/app.py
+++ b/sleap/gui/app.py
@@ -93,7 +93,6 @@ def __init__(self, labels_path: Optional[str] = None, *args, **kwargs):
         self.state["show edges"] = True
         self.state["edge style"] = "Line"
         self.state["fit"] = False
-        self.state["show trails"] = False
         self.state["color predicted"] = False
 
         self._initialize_gui()
@@ -389,12 +388,10 @@ def prev_vid():
             key="edge style",
         )
 
-        add_menu_check_item(viewMenu, "show trails", "Show Trails")
-
         add_submenu_choices(
             menu=viewMenu,
             title="Trail Length",
-            options=(10, 20, 50),
+            options=(0, 10, 20, 50),
             key="trail_length",
         )
 
@@ -823,7 +820,6 @@ def overlay_state_connect(overlay, state_key, overlay_attribute=None):
             ],
         )
 
-        overlay_state_connect(self.overlays["trails"], "show trails", "show")
        overlay_state_connect(self.overlays["trails"], "trail_length")
 
         overlay_state_connect(self.color_manager, "palette")
@@ -838,7 +834,7 @@ def overlay_state_connect(overlay, state_key, overlay_attribute=None):
         )
 
         # Set defaults
-        self.state["trail_length"] = 10
+        self.state["trail_length"] = 0
 
         # Emit signals for default that may have been set earlier
         self.state.emit("palette")
@@ -984,9 +980,6 @@ def plotFrame(self, *args, **kwargs):
 
         self.player.plot()
 
-        if self.state["fit"]:
-            self.player.zoomToFit()
-
     def _after_plot_update(self, player, frame_idx, selected_inst):
         """Called each time a new frame is drawn."""
 
@@ -1003,6 +996,9 @@ def _after_plot_update(self, player, frame_idx, selected_inst):
         if selected_inst is not None:
             player.view.selectInstance(selected_inst)
 
+        if self.state["fit"]:
+            player.zoomToFit()
+
         # Update related displays
         self.updateStatusMessage()
         self.on_data_update([UpdateTopic.on_frame])
diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py
index d754513cd..45d5e0812 100644
--- a/sleap/gui/commands.py
+++ b/sleap/gui/commands.py
@@ -1453,7 +1453,7 @@ def do_action(context: CommandContext, params: dict):
         # Determine range that should be affected
         if context.state["has_frame_range"]:
             # If range is selected in seekbar, use that
-            frame_range = tuple(*context.state["frame_range"])
+            frame_range = tuple(context.state["frame_range"])
         else:
             # Otherwise, range is current to last frame
             frame_range = (
diff --git a/sleap/gui/overlays/tracks.py b/sleap/gui/overlays/tracks.py
index 033efeb6a..a6bb6e358 100644
--- a/sleap/gui/overlays/tracks.py
+++ b/sleap/gui/overlays/tracks.py
@@ -36,8 +36,8 @@ class TrackTrailOverlay:
 
     labels: Labels = None
     player: "QtVideoPlayer" = None
-    trail_length: int = 10
-    show: bool = False
+    trail_length: int = 0
+    show: bool = True
 
     def get_track_trails(self, frame_selection: Iterable["LabeledFrame"]):
         """Get data needed to draw track trail.
diff --git a/sleap/gui/slider.py b/sleap/gui/slider.py
index 33ce33c71..18909630a 100644
--- a/sleap/gui/slider.py
+++ b/sleap/gui/slider.py
@@ -106,7 +106,7 @@ def visual_width(self):
 
     def get_height(self, container_height):
         if self.type == "track":
-            return 1.5
+            return 2
 
         height = container_height
         # if self.padded:
         height -= self.top_pad + self.bottom_pad
diff --git a/sleap/gui/video.py b/sleap/gui/video.py
index e5196d7bb..6c4c031a6 100644
--- a/sleap/gui/video.py
+++ b/sleap/gui/video.py
@@ -11,8 +11,17 @@
 >>> vp.addInstance(instance=my_instance, color=(r, g, b))
 """
 
+from collections import deque
 
-FORCE_REQUEST_AFTER_TIME_IN_SECONDS = 1
+
+# FORCE_REQUESTS controls whether we emit a signal to process frame requests
+# if we haven't processed any for a certain amount of time.
+# Usually the processing gets triggered by a timer but if the user is (e.g.)
+# dragging the mouse, the timer doesn't trigger.
+# FORCE_REQUESTS lets us update the frames in real time, assuming the load time
+# is short enough to do that.
+
+FORCE_REQUESTS = True
 
 from PySide2 import QtWidgets, QtCore
 
@@ -65,10 +74,14 @@ class LoadImageWorker(QtCore.QObject):
     load_queue = []
     video = None
     _last_process_time = 0
+    _force_request_wait_time = 1
+    _recent_load_times = None
 
     def __init__(self, *args, **kwargs):
         super(LoadImageWorker, self).__init__(*args, **kwargs)
 
+        self._recent_load_times = deque(maxlen=5)
+
         # Connect signal to processing function so that we can add processing
         # event to event queue from the request handler.
         self.process.connect(self.doProcessing)
@@ -87,12 +100,21 @@ def doProcessing(self):
         frame_idx = self.load_queue[-1]
         self.load_queue = []
 
-        # print(f"\t{frame_idx} starting to load")  # DEBUG
-
         try:
+
+            t0 = time.time()
+
             # Get image data
             frame = self.video.get_frame(frame_idx)
-        except:
+
+            self._recent_load_times.append(time.time() - t0)
+
+            # Set the time to wait before forcing a load request to a little
+            # longer than the average time it recently took to load a frame
+            avg_load_time = sum(self._recent_load_times) / len(self._recent_load_times)
+            self._force_request_wait_time = avg_load_time * 1.2
+
+        except Exception as e:
             frame = None
 
         if frame is not None:
@@ -115,9 +137,10 @@ def request(self, frame_idx):
 
         since_last = time.time() - self._last_process_time
 
-        if since_last > FORCE_REQUEST_AFTER_TIME_IN_SECONDS:
-            self._last_process_time = time.time()
-            self.process.emit()
+        if FORCE_REQUESTS:
+            if since_last > self._force_request_wait_time:
+                self._last_process_time = time.time()
+                self.process.emit()
 
 
 class QtVideoPlayer(QWidget):
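The comments above describe an adaptive throttle: the worker keeps the last few frame-load durations and only force-emits a processing request once the time since the last one exceeds a bit more than the recent average load time. A minimal, Qt-free sketch of the same idea (class and names are illustrative, not part of the PR):

```python
import time
from collections import deque


class AdaptiveThrottle:
    """Force an action only after an adaptive, load-time-based wait."""

    def __init__(self, window: int = 5, headroom: float = 1.2):
        self.recent = deque(maxlen=window)  # last N load durations (seconds)
        self.wait = 1.0                     # initial wait, as in the diff
        self.headroom = headroom
        self.last_forced = 0.0

    def record(self, seconds: float):
        """Record one load duration and adapt the wait time."""
        self.recent.append(seconds)
        self.wait = self.headroom * sum(self.recent) / len(self.recent)

    def should_force(self) -> bool:
        """True if enough time has passed since we last forced processing."""
        if time.time() - self.last_forced > self.wait:
            self.last_forced = time.time()
            return True
        return False
```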
diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py
index 3acfb30a4..6bd600791 100644
--- a/sleap/io/dataset.py
+++ b/sleap/io/dataset.py
@@ -7,49 +7,36 @@
 """
 
 import os
-import re
-import zipfile
-import atexit
+from collections import MutableSequence
+from typing import Callable, List, Union, Dict, Optional, Tuple, Text
 
 import attr
 import cattr
-import shutil
-import tempfile
-import numpy as np
-import scipy.io as sio
 import h5py as h5
-
-from collections import MutableSequence
-from typing import Callable, List, Union, Dict, Optional, Tuple, Text
+import numpy as np
 
 try:
     from typing import ForwardRef
 except:
     from typing import _ForwardRef as ForwardRef
 
-import pandas as pd
-
 from sleap.skeleton import Skeleton, Node
 from sleap.instance import (
     Instance,
-    Point,
     LabeledFrame,
     Track,
-    PredictedPoint,
-    PredictedInstance,
     make_instance_cattr,
     PointArray,
     PredictedPointArray,
+    PredictedInstance,
 )
 
 from sleap.io import pathutils
-from sleap.io.legacy import load_labels_json_old
 from sleap.io.video import Video
 from sleap.gui.suggestions import SuggestionFrame
 from sleap.gui.missingfiles import MissingFilesDialog
 from sleap.rangelist import RangeList
-from sleap.util import uniquify, weak_filename_match, json_dumps, json_loads
-
+from sleap.util import uniquify, json_dumps
 
 """
 The version number to put in the Labels JSON format.
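Circling back to the one-character fix in sleap/gui/commands.py above: `tuple(*frame_range)` unpacked the stored endpoints into separate positional arguments, which `tuple()` rejects. A quick demonstration:

```python
frame_range = [10, 20]

try:
    tuple(*frame_range)  # old call: equivalent to tuple(10, 20)
except TypeError as err:
    print(err)  # tuple expected at most 1 argument, got 2

print(tuple(frame_range))  # (10, 20) -- the fixed call
```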
@@ -526,6 +513,11 @@ def user_instances(self):
         """Returns list of all user (non-predicted) instances."""
         return [inst for inst in self.all_instances if type(inst) == Instance]
 
+    @property
+    def predicted_instances(self):
+        """Returns list of all predicted instances."""
+        return [inst for inst in self.all_instances if type(inst) == PredictedInstance]
+
     def instances(self, video: Video = None, skeleton: Skeleton = None):
         """
         Iterate over instances in the labels, optionally with filters.
@@ -896,6 +888,12 @@ def remove_negative_anchors(self, video: Video, frame_idx: int):
 
     # Methods for saving/loading
 
+    @classmethod
+    def from_json(cls, *args, **kwargs):
+        from sleap.io.format.labels_json import LabelsJsonAdaptor
+
+        return LabelsJsonAdaptor.from_json_data(*args, **kwargs)
+
     def extend_from(
         self, new_frames: Union["Labels", List[LabeledFrame]], unify: bool = False
     ):
@@ -1004,21 +1002,6 @@ def complex_merge_between(
 
         return merged, extra_base, extra_new
 
-    # @classmethod
-    # def merge_predictions_by_score(cls, extra_base: List[LabeledFrame], extra_new: List[LabeledFrame]):
-    #     """
-    #     Remove all predictions from input lists, return list with only
-    #     the merged predictions.
-    #
-    #     Args:
-    #         extra_base: list of `LabeledFrame` objects
-    #         extra_new: list of `LabeledFrame` objects
-    #             Conflicting frames should have same index in both lists.
-    #     Returns:
-    #         list of `LabeledFrame` objects with merged predictions
-    #     """
-    #     pass
-
     @staticmethod
     def finish_complex_merge(
         base_labels: "Labels", resolved_frames: List[LabeledFrame]
@@ -1157,871 +1140,103 @@ def to_json(self):
 
         # Unstructure the data into dicts and dump to JSON.
         return json_dumps(self.to_dict())
 
-    @staticmethod
-    def save_json(
-        labels: "Labels",
-        filename: str,
-        compress: bool = False,
-        save_frame_data: bool = False,
-        frame_data_format: str = "png",
-    ):
-        """
-        Save a Labels instance to a JSON format.
-
-        Args:
-            labels: The labels dataset to save.
-            filename: The filename to save the data to.
-            compress: Whether the data be zip compressed or not? If True,
-                the JSON will be compressed using Python's shutil.make_archive
-                command into a PKZIP zip file. If compress is True then
-                filename will have a .zip appended to it.
-            save_frame_data: Whether to save the image data for each frame.
-                For each video in the dataset, all frames that have labels
-                will be stored as an imgstore dataset.
-                If save_frame_data is True then compress will be forced to True
-                since the archive must contain both the JSON data and image
-                data stored in ImgStores.
-            frame_data_format: If save_frame_data is True, then this argument
-                is used to set the data format to use when writing frame
-                data to ImgStore objects. Supported formats should be:
-
-                * 'pgm',
-                * 'bmp',
-                * 'ppm',
-                * 'tif',
-                * 'png',
-                * 'jpg',
-                * 'npy',
-                * 'mjpeg/avi',
-                * 'h264/mkv',
-                * 'avc1/mp4'
-
-                Note: 'h264/mkv' and 'avc1/mp4' require separate installation
-                of these codecs on your system. They are excluded from SLEAP
-                because of their GPL license.
-
-        Returns:
-            None
-        """
-
-        # Lets make a temporary directory to store the image frame data or pre-compressed json
-        # in case we need it.
-        with tempfile.TemporaryDirectory() as tmp_dir:
-
-            # If we are saving frame data along with the datasets. We will replace videos with
-            # new video object that represent video data from just the labeled frames.
-            if save_frame_data:
-
-                # Create a set of new Video objects with imgstore backends. One for each
-                # of the videos. We will only include the labeled frames though.
We will - # then replace each video with this new video - new_videos = labels.save_frame_data_imgstore( - output_dir=tmp_dir, format=frame_data_format - ) - - # Make video paths relative - for vid in new_videos: - tmp_path = vid.filename - # Get the parent dir of the YAML file. - # Use "/" since this works on Windows and posix - img_store_dir = ( - os.path.basename(os.path.split(tmp_path)[0]) - + "/" - + os.path.basename(tmp_path) - ) - # Change to relative path - vid.backend.filename = img_store_dir - - # Convert to a dict, not JSON yet, because we need to patch up the videos - d = labels.to_dict() - d["videos"] = Video.cattr().unstructure(new_videos) - - else: - d = labels.to_dict() - - if compress or save_frame_data: - - # Ensure that filename ends with .json - # shutil will append .zip - filename = re.sub("(\.json)?(\.zip)?$", ".json", filename) - - # Write the json to the tmp directory, we will zip it up with the frame data. - full_out_filename = os.path.join(tmp_dir, os.path.basename(filename)) - json_dumps(d, full_out_filename) - - # Create the archive - shutil.make_archive(base_name=filename, root_dir=tmp_dir, format="zip") - - # If the user doesn't want to compress, then just write the json to the filename - else: - json_dumps(d, filename) - @classmethod - def from_json( - cls, data: Union[str, dict], match_to: Optional["Labels"] = None - ) -> "Labels": - """ - Create instance of class from data in dictionary. - - Method is used by other methods that load from JSON. - - Args: - data: Dictionary, deserialized from JSON. - match_to: If given, we'll replace particular objects in the - data dictionary with *matching* objects in the match_to - :class:`Labels` object. This ensures that the newly - instantiated :class:`Labels` can be merged without - duplicate matching objects (e.g., :class:`Video` objects ). - Returns: - A new :class:`Labels` object. - """ - - # Parse the json string if needed. - if type(data) is str: - dicts = json_loads(data) - else: - dicts = data - - dicts["tracks"] = dicts.get( - "tracks", [] - ) # don't break if json doesn't include tracks - - # First, deserialize the skeletons, videos, and nodes lists. - # The labels reference these so we will need them while deserializing. 
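The removed deserialization code above and below (now living on in the labels_json adaptor) leans on one cattr idiom worth spelling out: objects are serialized as integer indices into shared lists, and structure hooks map each index back to the already-built object so identity is preserved. A toy version — `Node` and `Edge` here are stand-ins, not the SLEAP classes:

```python
import attr
import cattr


@attr.s(auto_attribs=True)
class Node:
    name: str


@attr.s(auto_attribs=True)
class Edge:
    source: Node
    destination: Node


nodes = [Node("head"), Node("thorax")]

conv = cattr.Converter()
# Structure a Node from its index into the shared `nodes` list.
conv.register_structure_hook(Node, lambda idx, _: nodes[int(idx)])

edge = conv.structure({"source": 0, "destination": 1}, Edge)
assert edge.source is nodes[0] and edge.destination is nodes[1]
```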
- nodes = cattr.structure(dicts["nodes"], List[Node]) - - idx_to_node = {i: nodes[i] for i in range(len(nodes))} - skeletons = Skeleton.make_cattr(idx_to_node).structure( - dicts["skeletons"], List[Skeleton] - ) - videos = Video.cattr().structure(dicts["videos"], List[Video]) - - try: - # First try unstructuring tuple (newer format) - track_cattr = cattr.Converter( - unstruct_strat=cattr.UnstructureStrategy.AS_TUPLE - ) - tracks = track_cattr.structure(dicts["tracks"], List[Track]) - except: - # Then try unstructuring dict (older format) - try: - tracks = cattr.structure(dicts["tracks"], List[Track]) - except: - raise ValueError("Unable to load tracks as tuple or dict!") - - # if we're given a Labels object to match, use its objects when they match - if match_to is not None: - for idx, sk in enumerate(skeletons): - for old_sk in match_to.skeletons: - if sk.matches(old_sk): - # use nodes from matched skeleton - for (node, match_node) in zip(sk.nodes, old_sk.nodes): - node_idx = nodes.index(node) - nodes[node_idx] = match_node - # use skeleton from match - skeletons[idx] = old_sk - break - for idx, vid in enumerate(videos): - for old_vid in match_to.videos: - # compare last three parts of path - if vid.filename == old_vid.filename or weak_filename_match( - vid.filename, old_vid.filename - ): - # use video from match - videos[idx] = old_vid - break - - suggestions = [] - if "suggestions" in dicts: - suggestions_cattr = cattr.Converter() - suggestions_cattr.register_structure_hook( - Video, lambda x, type: videos[int(x)] - ) - try: - suggestions = suggestions_cattr.structure( - dicts["suggestions"], List[SuggestionFrame] - ) - except Exception as e: - print("Error while loading suggestions (1)") - print(e) - - try: - # Convert old suggestion format to new format. - # Old format: {video: list of frame indices} - # New format: [SuggestionFrames] - old_suggestions = suggestions_cattr.structure( - dicts["suggestions"], Dict[Video, List] - ) - for video in old_suggestions.keys(): - suggestions.extend( - [ - SuggestionFrame(video, idx) - for idx in old_suggestions[video] - ] - ) - except Exception as e: - print("Error while loading suggestions (2)") - print(e) - pass - - if "negative_anchors" in dicts: - negative_anchors_cattr = cattr.Converter() - negative_anchors_cattr.register_structure_hook( - Video, lambda x, type: videos[int(x)] - ) - negative_anchors = negative_anchors_cattr.structure( - dicts["negative_anchors"], Dict[Video, List] - ) - else: - negative_anchors = dict() - - # If there is actual labels data, get it. 
- if "labels" in dicts: - label_cattr = make_instance_cattr() - label_cattr.register_structure_hook( - Skeleton, lambda x, type: skeletons[int(x)] - ) - label_cattr.register_structure_hook(Video, lambda x, type: videos[int(x)]) - label_cattr.register_structure_hook( - Node, lambda x, type: x if isinstance(x, Node) else nodes[int(x)] - ) - label_cattr.register_structure_hook( - Track, lambda x, type: None if x is None else tracks[int(x)] - ) + def load_file(cls, filename: str, *args, **kwargs): + """Load file, detecting format from filename.""" + from .format import read - labels = label_cattr.structure(dicts["labels"], List[LabeledFrame]) - else: - labels = [] - - return cls( - labeled_frames=labels, - videos=videos, - skeletons=skeletons, - nodes=nodes, - suggestions=suggestions, - negative_anchors=negative_anchors, - tracks=tracks, - ) + return read(filename, for_object="labels", *args, **kwargs) @classmethod - def load_json( - cls, - filename: str, - video_callback: Optional[Callable] = None, - match_to: Optional["Labels"] = None, - ) -> "Labels": - """ - Deserialize JSON file as new :class:`Labels` instance. - - Args: - filename: Path to JSON file. - video_callback: A callback function that which can modify - video paths before we try to create the corresponding - :class:`Video` objects. Usually you'll want to pass - a callback created by :meth:`make_video_callback` - or :meth:`make_gui_video_callback`. - Alternately, if you pass a list of strings we'll construct a - non-gui callback with those strings as the search paths. - match_to: If given, we'll replace particular objects in the - data dictionary with *matching* objects in the match_to - :class:`Labels` object. This ensures that the newly - instantiated :class:`Labels` can be merged without - duplicate matching objects (e.g., :class:`Video` objects ). - Returns: - A new :class:`Labels` object. - """ - - tmp_dir = None - - # Check if the file is a zipfile for not. - if zipfile.is_zipfile(filename): - - # Make a tmpdir, located in the directory that the file exists, to unzip - # its contents. - tmp_dir = os.path.join( - os.path.dirname(filename), - f"tmp_{os.getpid()}_{os.path.basename(filename)}", - ) - if os.path.exists(tmp_dir): - shutil.rmtree(tmp_dir, ignore_errors=True) - try: - os.mkdir(tmp_dir) - except FileExistsError: - pass - - # tmp_dir = tempfile.mkdtemp(dir=os.path.dirname(filename)) - - try: - - # Register a cleanup routine that deletes the tmpdir on program exit - # if something goes wrong. The True is for ignore_errors - atexit.register(shutil.rmtree, tmp_dir, True) - - # Uncompress the data into the directory - shutil.unpack_archive(filename, extract_dir=tmp_dir) - - # We can now open the JSON file, save the zip file and - # replace file with the first JSON file we find in the archive. - json_files = [ - os.path.join(tmp_dir, file) - for file in os.listdir(tmp_dir) - if file.endswith(".json") - ] - - if len(json_files) == 0: - raise ValueError( - f"No JSON file found inside {filename}. Are you sure this is a valid sLEAP dataset." - ) - - filename = json_files[0] - - except Exception as ex: - # If we had problems, delete the temp directory and reraise the exception. - shutil.rmtree(tmp_dir, ignore_errors=True) - raise - - # Open and parse the JSON in filename - with open(filename, "r") as file: - - # FIXME: Peek into the json to see if there is version string. - # We do this to tell apart old JSON data from leap_dev vs the - # newer format for sLEAP. 
- json_str = file.read() - dicts = json_loads(json_str) - - # If we have a version number, then it is new sLEAP format - if "version" in dicts: - - # Cache the working directory. - cwd = os.getcwd() - # Replace local video paths (for imagestore) - if tmp_dir: - for vid in dicts["videos"]: - vid["backend"]["filename"] = os.path.join( - tmp_dir, vid["backend"]["filename"] - ) - - if hasattr(video_callback, "__iter__"): - # If the callback is an iterable, then we'll expect it to be a - # list of strings and build a non-gui callback with those as - # the search paths. - search_paths = [path for path in video_callback] - video_callback = cls.make_video_callback(search_paths) - - # Use the callback if given to handle missing videos - if callable(video_callback): - abort = video_callback(dicts["videos"]) - if abort: - raise FileNotFoundError - - # Try to load the labels filename. - try: - labels = Labels.from_json(dicts, match_to=match_to) - - except FileNotFoundError: - - # FIXME: We are going to the labels JSON that has references to - # video files. Lets change directory to the dirname of the json file - # so that relative paths will be from this directory. Maybe - # it is better to feed the dataset dirname all the way down to - # the Video object. This seems like less coupling between classes - # though. - if os.path.dirname(filename) != "": - os.chdir(os.path.dirname(filename)) - - # Try again - labels = Labels.from_json(dicts, match_to=match_to) - - except Exception as ex: - # Ok, we give up, where the hell are these videos! - raise # Re-raise. - finally: - os.chdir(cwd) # Make sure to change back if we have problems. - - return labels - - else: - frames = load_labels_json_old(data_path=filename, parsed_json=dicts) - return Labels(frames) - - @staticmethod - def save_hdf5( - labels: "Labels", - filename: str, - append: bool = False, - save_frame_data: bool = False, - frame_data_format: str = "png", + def save_file( + cls, labels: "Labels", filename: str, default_suffix: str = "", *args, **kwargs ): - """ - Serialize the labels dataset to an HDF5 file. + """Save file, detecting format from filename. Args: - labels: The :class:`Labels` dataset to save - filename: The file to serialize the dataset to. - append: Whether to append these labeled frames to the file - or not. - save_frame_data: Whether to save the image frame data for - any labeled frame as well. This is useful for uploading - the HDF5 for model training when video files are to - large to move. This will only save video frames that - have some labeled instances. - frame_data_format: If save_frame_data is True, then this argument - is used to set the data format to use when encoding images - saved in HDF5. Supported formats include: - - * "" for no encoding (ndarray) - * "png" - * "jpg" - * anything else supported by `cv2.imencode` + labels: The dataset to save. + filename: Path where we'll save it. We attempt to detect format + from the suffix (e.g., ".json"). + default_suffix: If we can't detect valid suffix on filename, + we can add default suffix to filename (and use corresponding + format). Doesn't need to have "." before file extension. + + Raises: + ValueError: If cannot detect valid filetype. Returns: - None + None. """ + # Convert to full (absolute) path + filename = os.path.abspath(filename) - # Delete the file if it exists, we want to start from scratch since - # h5py truncates the file which seems to not actually delete data - # from the file. Don't if we are appending of course. 
- if os.path.exists(filename) and not append: - os.unlink(filename) - - # Serialize all the meta-data to JSON. - d = labels.to_dict(skip_labels=True) - - if save_frame_data: - new_videos = labels.save_frame_data_hdf5(filename, frame_data_format) - - # Replace path to video file with "." (which indicates that the - # video is in the same file as the HDF5 labels dataset). - # Otherwise, the video paths will break if the HDF5 labels - # dataset file is moved. - for vid in new_videos: - vid.backend.filename = "." - - d["videos"] = Video.cattr().unstructure(new_videos) - - with h5.File(filename, "a") as f: - - # Add all the JSON metadata - meta_group = f.require_group("metadata") - - # If we are appending and there already exists JSON metadata - if append and "json" in meta_group.attrs: - - # Otherwise, we need to read the JSON and append to the lists - old_labels = Labels.from_json( - meta_group.attrs["json"].tostring().decode() - ) - - # A function to join to list but only include new non-dupe entries - # from the right hand list. - def append_unique(old, new): - unique = [] - for x in new: - try: - matches = [y.matches(x) for y in old] - except AttributeError: - matches = [x == y for y in old] - - # If there were no matches, this is a unique object. - if sum(matches) == 0: - unique.append(x) - else: - # If we have an object that matches, replace the instance with - # the one from the new list. This will will make sure objects - # on the Instances are the same as those in the Labels lists. - for i, match in enumerate(matches): - if match: - old[i] = x - - return old + unique - - # Append the lists - labels.tracks = append_unique(old_labels.tracks, labels.tracks) - labels.skeletons = append_unique(old_labels.skeletons, labels.skeletons) - labels.videos = append_unique(old_labels.videos, labels.videos) - labels.nodes = append_unique(old_labels.nodes, labels.nodes) - - # FIXME: Do something for suggestions and negative_anchors - - # Get the dict for JSON and save it over the old data - d = labels.to_dict(skip_labels=True) - - if not append: - for key in ("videos", "tracks", "suggestions"): - - # Convert for saving in hdf5 dataset - data = [np.string_(json_dumps(item)) for item in d[key]] - - hdf5_key = f"{key}_json" - - # Save in its own dataset (e.g., videos_json) - f.create_dataset(hdf5_key, data=data, maxshape=(None,)) - - # Clear from dict since we don't want to save this in attribute - d[key] = [] - - # Output the dict to JSON - meta_group.attrs["json"] = np.string_(json_dumps(d)) - - # FIXME: We can probably construct these from attrs fields - # We will store Instances and PredcitedInstances in the same - # table. instance_type=0 or Instance and instance_type=1 for - # PredictedInstance, score will be ignored for Instances. 
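Before the arrays are filled in below, it may help to see that single-table layout in isolation. This sketch reuses the `instance_dtype` defined verbatim in the code that follows; the file path is arbitrary:

```python
import numpy as np
import h5py

# instance_type: 0 = Instance, 1 = PredictedInstance (score unused for 0)
instance_dtype = np.dtype(
    [
        ("instance_id", "i8"),
        ("instance_type", "u1"),
        ("frame_id", "u8"),
        ("skeleton", "u4"),
        ("track", "i4"),
        ("from_predicted", "i8"),
        ("score", "f4"),
        ("point_id_start", "u8"),
        ("point_id_end", "u8"),
    ]
)

rows = np.zeros(2, dtype=instance_dtype)
rows[1] = (1, 1, 0, 0, -1, -1, 0.87, 5, 10)  # a predicted instance

with h5py.File("/tmp/instances_demo.h5", "w") as f:
    # maxshape=(None,) keeps the table appendable, as in the code below
    f.create_dataset("instances", data=rows, maxshape=(None,))
```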
- instance_dtype = np.dtype( - [ - ("instance_id", "i8"), - ("instance_type", "u1"), - ("frame_id", "u8"), - ("skeleton", "u4"), - ("track", "i4"), - ("from_predicted", "i8"), - ("score", "f4"), - ("point_id_start", "u8"), - ("point_id_end", "u8"), - ] - ) - frame_dtype = np.dtype( - [ - ("frame_id", "u8"), - ("video", "u4"), - ("frame_idx", "u8"), - ("instance_id_start", "u8"), - ("instance_id_end", "u8"), - ] - ) + # Make sure that all directories for path exist + os.makedirs(os.path.dirname(filename), exist_ok=True) - num_instances = len(labels.all_instances) - max_skeleton_size = max([len(s.nodes) for s in labels.skeletons], default=0) + # Detect filetype and use appropriate save method + # if not filename.endswith((".json", ".zip", ".h5")) and default_suffix: + # filename += f".{default_suffix}" - # Initialize data arrays for serialization - points = np.zeros(num_instances * max_skeleton_size, dtype=Point.dtype) - pred_points = np.zeros( - num_instances * max_skeleton_size, dtype=PredictedPoint.dtype - ) - instances = np.zeros(num_instances, dtype=instance_dtype) - frames = np.zeros(len(labels), dtype=frame_dtype) + from .format import write - # Pre compute some structures to make serialization faster - skeleton_to_idx = { - skeleton: labels.skeletons.index(skeleton) - for skeleton in labels.skeletons - } - track_to_idx = { - track: labels.tracks.index(track) for track in labels.tracks - } - track_to_idx[None] = -1 - video_to_idx = { - video: labels.videos.index(video) for video in labels.videos - } - instance_type_to_idx = {Instance: 0, PredictedInstance: 1} - - # Each instance we create will have and index in the dataset, keep track of - # these so we can quickly add from_predicted links on a second pass. - instance_to_idx = {} - instances_with_from_predicted = [] - instances_from_predicted = [] - - # If we are appending, we need look inside to see what frame, instance, and point - # ids we need to start from. This gives us offsets to use. - if append and "points" in f: - point_id_offset = f["points"].shape[0] - pred_point_id_offset = f["pred_points"].shape[0] - instance_id_offset = f["instances"][-1]["instance_id"] + 1 - frame_id_offset = int(f["frames"][-1]["frame_id"]) + 1 - else: - point_id_offset = 0 - pred_point_id_offset = 0 - instance_id_offset = 0 - frame_id_offset = 0 - - point_id = 0 - pred_point_id = 0 - instance_id = 0 - - for frame_id, label in enumerate(labels): - frames[frame_id] = ( - frame_id + frame_id_offset, - video_to_idx[label.video], - label.frame_idx, - instance_id + instance_id_offset, - instance_id + instance_id_offset + len(label.instances), - ) - for instance in label.instances: - - # Add this instance to our lookup structure we will need for from_predicted - # links - instance_to_idx[instance] = instance_id - - parray = instance.get_points_array(copy=False, full=True) - instance_type = type(instance) - - # Check whether we are working with a PredictedInstance or an Instance. - if instance_type is PredictedInstance: - score = instance.score - pid = pred_point_id + pred_point_id_offset - else: - score = np.nan - pid = point_id + point_id_offset - - # Keep track of any from_predicted instance links, we will insert the - # correct instance_id in the dataset after we are done. 
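The two passes the comment above describes can be reduced to a runnable toy: record each instance's row index on the first pass, then resolve object references to row indices once every instance has one (`Inst` is a stand-in class, not SLEAP's):

```python
class Inst:
    """Stand-in for sleap.instance.Instance."""

    def __init__(self, from_predicted=None):
        self.from_predicted = from_predicted


pred = Inst()
user = Inst(from_predicted=pred)
all_instances = [pred, user]

rows = [{"from_predicted": -1} for _ in all_instances]

# Pass 1: assign every instance a row index; remember deferred links.
instance_to_idx = {}
deferred = []
for row_idx, inst in enumerate(all_instances):
    instance_to_idx[inst] = row_idx
    if inst.from_predicted is not None:
        deferred.append((row_idx, inst.from_predicted))

# Pass 2: resolve object references into row indices.
for row_idx, source in deferred:
    if source in instance_to_idx:  # drop links to instances not being saved
        rows[row_idx]["from_predicted"] = instance_to_idx[source]

assert rows[1]["from_predicted"] == 0
```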
- if instance.from_predicted: - instances_with_from_predicted.append(instance_id) - instances_from_predicted.append(instance.from_predicted) - - # Copy all the data - instances[instance_id] = ( - instance_id + instance_id_offset, - instance_type_to_idx[instance_type], - frame_id, - skeleton_to_idx[instance.skeleton], - track_to_idx[instance.track], - -1, - score, - pid, - pid + len(parray), - ) - - # If these are predicted points, copy them to the predicted point array - # otherwise, use the normal point array - if type(parray) is PredictedPointArray: - pred_points[ - pred_point_id : pred_point_id + len(parray) - ] = parray - pred_point_id = pred_point_id + len(parray) - else: - points[point_id : point_id + len(parray)] = parray - point_id = point_id + len(parray) - - instance_id = instance_id + 1 - - # Add from_predicted links - for instance_id, from_predicted in zip( - instances_with_from_predicted, instances_from_predicted - ): - try: - instances[instance_id]["from_predicted"] = instance_to_idx[ - from_predicted - ] - except KeyError: - # If we haven't encountered the from_predicted instance yet then don't save the link. - # It’s possible for a user to create a regular instance from a predicted instance and then - # delete all predicted instances from the file, but in this case I don’t think there’s any reason - # to remember which predicted instance the regular instance came from. - pass - - # We pre-allocated our points array with max possible size considering the max - # skeleton size, drop any unused points. - points = points[0:point_id] - pred_points = pred_points[0:pred_point_id] - - # Create datasets if we need to - if append and "points" in f: - f["points"].resize((f["points"].shape[0] + points.shape[0]), axis=0) - f["points"][-points.shape[0] :] = points - f["pred_points"].resize( - (f["pred_points"].shape[0] + pred_points.shape[0]), axis=0 - ) - f["pred_points"][-pred_points.shape[0] :] = pred_points - f["instances"].resize( - (f["instances"].shape[0] + instances.shape[0]), axis=0 - ) - f["instances"][-instances.shape[0] :] = instances - f["frames"].resize((f["frames"].shape[0] + frames.shape[0]), axis=0) - f["frames"][-frames.shape[0] :] = frames - else: - f.create_dataset( - "points", data=points, maxshape=(None,), dtype=Point.dtype - ) - f.create_dataset( - "pred_points", - data=pred_points, - maxshape=(None,), - dtype=PredictedPoint.dtype, - ) - f.create_dataset( - "instances", data=instances, maxshape=(None,), dtype=instance_dtype - ) - f.create_dataset( - "frames", data=frames, maxshape=(None,), dtype=frame_dtype - ) + write(filename, labels, *args, **kwargs) @classmethod - def load_hdf5( - cls, filename: str, video_callback=None, match_to: Optional["Labels"] = None - ): - """ - Deserialize HDF5 file as new :class:`Labels` instance. - - Args: - filename: Path to HDF5 file. - video_callback: A callback function that which can modify - video paths before we try to create the corresponding - :class:`Video` objects. Usually you'll want to pass - a callback created by :meth:`make_video_callback` - or :meth:`make_gui_video_callback`. - Alternately, if you pass a list of strings we'll construct a - non-gui callback with those strings as the search paths. - match_to: If given, we'll replace particular objects in the - data dictionary with *matching* objects in the match_to - :class:`Labels` object. This ensures that the newly - instantiated :class:`Labels` can be merged without - duplicate matching objects (e.g., :class:`Video` objects ). 
+ def load_json(cls, filename: str, *args, **kwargs) -> "Labels": + from .format import read - Returns: - A new :class:`Labels` object. - """ - with h5.File(filename, "r") as f: + return read(filename, for_object="labels", as_format="json", *args, **kwargs) - # Extract the Labels JSON metadata and create Labels object with just - # this metadata. - dicts = json_loads( - f.require_group("metadata").attrs["json"].tostring().decode() - ) + @classmethod + def save_json(cls, labels: "Labels", filename: str, *args, **kwargs): + from .format import write - for key in ("videos", "tracks", "suggestions"): - hdf5_key = f"{key}_json" - if hdf5_key in f: - items = [json_loads(item_json) for item_json in f[hdf5_key]] - dicts[key] = items - - # Video path "." means the video is saved in same file as labels, - # so replace these paths. - for video_item in dicts["videos"]: - if video_item["backend"]["filename"] == ".": - video_item["backend"]["filename"] = filename - - if hasattr(video_callback, "__iter__"): - # If the callback is an iterable, then we'll expect it to be a - # list of strings and build a non-gui callback with those as - # the search paths. - search_paths = [path for path in video_callback] - video_callback = cls.make_video_callback(search_paths) - - # Use the callback if given to handle missing videos - if callable(video_callback): - video_callback(dicts["videos"]) - - labels = cls.from_json(dicts, match_to=match_to) - - frames_dset = f["frames"][:] - instances_dset = f["instances"][:] - points_dset = f["points"][:] - pred_points_dset = f["pred_points"][:] - - # Rather than instantiate a bunch of Point\PredictedPoint objects, we will - # use inplace numpy recarrays. This will save a lot of time and memory - # when reading things in. - points = PointArray(buf=points_dset, shape=len(points_dset)) - pred_points = PredictedPointArray( - buf=pred_points_dset, shape=len(pred_points_dset) - ) + write(filename, labels, as_format="json", *args, **kwargs) - # Extend the tracks list with a None track. We will signify this with a -1 in the - # data which will map to last element of tracks - tracks = labels.tracks.copy() - tracks.extend([None]) - - # A dict to keep track of instances that have a from_predicted link. The key is the - # instance and the value is the index of the instance. 
- from_predicted_lookup = {} - - # Create the instances - instances = [] - for i in instances_dset: - track = tracks[i["track"]] - skeleton = labels.skeletons[i["skeleton"]] - - if i["instance_type"] == 0: # Instance - instance = Instance( - skeleton=skeleton, - track=track, - points=points[i["point_id_start"] : i["point_id_end"]], - ) - else: # PredictedInstance - instance = PredictedInstance( - skeleton=skeleton, - track=track, - points=pred_points[i["point_id_start"] : i["point_id_end"]], - score=i["score"], - ) - instances.append(instance) - - if i["from_predicted"] != -1: - from_predicted_lookup[instance] = i["from_predicted"] - - # Make a second pass to add any from_predicted links - for instance, from_predicted_idx in from_predicted_lookup.items(): - instance.from_predicted = instances[from_predicted_idx] - - # Create the labeled frames - frames = [ - LabeledFrame( - video=labels.videos[frame["video"]], - frame_idx=frame["frame_idx"], - instances=instances[ - frame["instance_id_start"] : frame["instance_id_end"] - ], - ) - for i, frame in enumerate(frames_dset) - ] + @classmethod + def load_hdf5(cls, filename, *args, **kwargs): + from .format import read - labels.labeled_frames = frames + return read(filename, for_object="labels", as_format="hdf5_v1", *args, **kwargs) - # Do the stuff that should happen after we have labeled frames - labels._build_lookup_caches() + @classmethod + def save_hdf5(cls, labels, filename, *args, **kwargs): + from .format import write - return labels + write(filename, labels, as_format="hdf5_v1", *args, **kwargs) @classmethod - def load_file(cls, filename: str, *args, **kwargs): - """Load file, detecting format from filename.""" - if filename.endswith((".h5", ".hdf5")): - return cls.load_hdf5(filename, *args, **kwargs) - elif filename.endswith((".json", ".json.zip")): - return cls.load_json(filename, *args, **kwargs) - elif filename.endswith(".csv"): - # for now, the only csv we support is the DeepLabCut format - return cls.load_deeplabcut_csv(filename) - else: - raise ValueError(f"Cannot detect filetype for {filename}") + def load_leap_matlab(cls, filename, *args, **kwargs): + from .format import read + + return read(filename, for_object="labels", as_format="leap", *args, **kwargs) @classmethod - def save_file( - cls, labels: "Labels", filename: str, default_suffix: str = "", *args, **kwargs - ): - """Save file, detecting format from filename. + def load_deeplabcut_csv(cls, filename: str) -> "Labels": + from sleap.io.format.deeplabcut import LabelsDeepLabCutAdaptor + from sleap.io.format.filehandle import FileHandle - Args: - labels: The dataset to save. - filename: Path where we'll save it. We attempt to detect format - from the suffix (e.g., ".json"). - default_suffix: If we can't detect valid suffix on filename, - we can add default suffix to filename (and use corresponding - format). Doesn't need to have "." before file extension. + return LabelsDeepLabCutAdaptor.read(FileHandle(filename)) - Raises: - ValueError: If cannot detect valid filetype. + @classmethod + def load_coco( + cls, filename: str, img_dir: str, use_missing_gui: bool = False, + ) -> "Labels": + from sleap.io.format.coco import LabelsCocoAdaptor + from sleap.io.format.filehandle import FileHandle - Returns: - None. 
- """ - # Convert to full (absolute) path - filename = os.path.abspath(filename) + return LabelsCocoAdaptor.read(FileHandle(filename), img_dir, use_missing_gui) - # Make sure that all directories for path exist - os.makedirs(os.path.dirname(filename), exist_ok=True) + @classmethod + def from_deepposekit( + cls, filename: str, video_path: str, skeleton_path: str + ) -> "Labels": + from sleap.io.format.deepposekit import LabelsDeepPoseKitAdaptor + from sleap.io.format.filehandle import FileHandle - # Detect filetype and use appropriate save method - if not filename.endswith((".json", ".zip", ".h5")) and default_suffix: - filename += f".{default_suffix}" - if filename.endswith((".json", ".zip")): - compress = filename.endswith(".zip") - cls.save_json(labels=labels, filename=filename, compress=compress, **kwargs) - elif filename.endswith(".h5"): - cls.save_hdf5(labels=labels, filename=filename, **kwargs) - else: - raise ValueError(f"Cannot detect filetype for {filename}") + return LabelsDeepPoseKitAdaptor.read( + FileHandle(filename), video_path, skeleton_path + ) def save_frame_data_imgstore( self, output_dir: str = "./", format: str = "png", all_labels: bool = False @@ -2102,365 +1317,6 @@ def save_frame_data_hdf5( return new_vids - @staticmethod - def _unwrap_mat_scalar(a): - """Extract single value from nested MATLAB file data.""" - if a.shape == (1,): - return Labels._unwrap_mat_scalar(a[0]) - else: - return a - - @staticmethod - def _unwrap_mat_array(a): - """Extract list of values from nested MATLAB file data.""" - b = a[0][0] - c = [Labels._unwrap_mat_scalar(x) for x in b] - return c - - @classmethod - def load_leap_matlab(cls, filename: str, gui: bool = True) -> "Labels": - """Load LEAP MATLAB file as dataset. - - Args: - filename: Path to matlab file. - Returns: - The :class:`Labels` dataset. 
- """ - mat_contents = sio.loadmat(filename) - - box_path = Labels._unwrap_mat_scalar(mat_contents["boxPath"]) - - # If the video file isn't found, try in the same dir as the mat file - if not os.path.exists(box_path): - file_dir = os.path.dirname(filename) - box_path_name = box_path.split("\\")[-1] # assume windows path - box_path = os.path.join(file_dir, box_path_name) - - if not os.path.exists(box_path): - if gui: - video_paths = [box_path] - missing = [True] - okay = MissingFilesDialog(video_paths, missing).exec_() - - if not okay or missing[0]: - return - - box_path = video_paths[0] - else: - # Ignore missing videos if not loading from gui - box_path = "" - - if os.path.exists(box_path): - vid = Video.from_hdf5( - dataset="box", filename=box_path, input_format="channels_first" - ) - else: - vid = None - - nodes_ = mat_contents["skeleton"]["nodes"] - edges_ = mat_contents["skeleton"]["edges"] - points_ = mat_contents["positions"] - - edges_ = edges_ - 1 # convert matlab 1-indexing to python 0-indexing - - nodes = Labels._unwrap_mat_array(nodes_) - edges = Labels._unwrap_mat_array(edges_) - - nodes = list(map(str, nodes)) # convert np._str to str - - sk = Skeleton(name=filename) - sk.add_nodes(nodes) - for edge in edges: - sk.add_edge(source=nodes[edge[0]], destination=nodes[edge[1]]) - - labeled_frames = [] - node_count, _, frame_count = points_.shape - - for i in range(frame_count): - new_inst = Instance(skeleton=sk) - for node_idx, node in enumerate(nodes): - x = points_[node_idx][0][i] - y = points_[node_idx][1][i] - new_inst[node] = Point(x, y) - if len(new_inst.points): - new_frame = LabeledFrame(video=vid, frame_idx=i) - new_frame.instances = (new_inst,) - labeled_frames.append(new_frame) - - labels = cls(labeled_frames=labeled_frames, videos=[vid], skeletons=[sk]) - - return labels - - @classmethod - def load_deeplabcut_csv(cls, filename: str) -> "Labels": - """Load DeepLabCut csv file as dataset. - - Args: - filename: Path to csv file. - Returns: - The :class:`Labels` dataset. - """ - - # At the moment we don't need anything from the config file, - # but the code to read it is here in case we do in the future. - - # # Try to find the config file by walking up file path starting at csv file looking for config.csv - # last_dir = None - # file_dir = os.path.dirname(filename) - # config_filename = "" - - # while file_dir != last_dir: - # last_dir = file_dir - # file_dir = os.path.dirname(file_dir) - # config_filename = os.path.join(file_dir, 'config.yaml') - # if os.path.exists(config_filename): - # break - - # # If we couldn't find a config file, give up - # if not os.path.exists(config_filename): return - - # with open(config_filename, 'r') as f: - # config = yaml.load(f, Loader=yaml.SafeLoader) - - # x1 = config['x1'] - # y1 = config['y1'] - # x2 = config['x2'] - # y2 = config['y2'] - - data = pd.read_csv(filename, header=[1, 2]) - - # Create the skeleton from the list of nodes in the csv file - # Note that DeepLabCut doesn't have edges, so these will have to be added by user later - node_names = [n[0] for n in list(data)[1::2]] - - skeleton = Skeleton() - skeleton.add_nodes(node_names) - - # Create an imagestore `Video` object from frame images. - # This may not be ideal for large projects, since we're reading in - # each image and then writing it out in a new directory. 
- - img_files = data.ix[:, 0] # get list of all images - - # the image filenames in the csv may not match where the user has them - # so we'll change the directory to match where the user has the csv - def fix_img_path(img_dir, img_filename): - img_filename = os.path.basename(img_filename) - img_filename = os.path.join(img_dir, img_filename) - return img_filename - - img_dir = os.path.dirname(filename) - img_files = list(map(lambda f: fix_img_path(img_dir, f), img_files)) - - # we'll put the new imgstore in the same directory as the current csv - imgstore_name = os.path.join(os.path.dirname(filename), "sleap_video") - - # create the imgstore (or open if it already exists) - if os.path.exists(imgstore_name): - video = Video.from_filename(imgstore_name) - else: - video = Video.imgstore_from_filenames(img_files, imgstore_name) - - labels = [] - - for i in range(len(data)): - # get points for each node - instance_points = dict() - for node in node_names: - x, y = data[(node, "x")][i], data[(node, "y")][i] - instance_points[node] = Point(x, y) - # create instance with points (we can assume there's only one instance per frame) - instance = Instance(skeleton=skeleton, points=instance_points) - # create labeledframe and add it to list - label = LabeledFrame(video=video, frame_idx=i, instances=[instance]) - labels.append(label) - - return cls(labels) - - @classmethod - def load_coco( - cls, filename: str, img_dir: str, use_missing_gui: bool = False - ) -> "Labels": - with open(filename, "r") as file: - json_str = file.read() - dicts = json_loads(json_str) - - # Make skeletons from "categories" - skeleton_map = dict() - for category in dicts["categories"]: - skeleton = Skeleton(name=category["name"]) - skeleton_id = category["id"] - node_names = category["keypoints"] - skeleton.add_nodes(node_names) - - try: - for src_idx, dst_idx in category["skeleton"]: - skeleton.add_edge(node_names[src_idx], node_names[dst_idx]) - except IndexError as e: - # According to the COCO data format specifications[^1], the edges - # are supposed to be 1-indexed. But in some of their own - # dataset the edges are 1-indexed! So we'll try. - # [1]: http://cocodataset.org/#format-data - - # Clear any edges we already created using 0-indexing - skeleton.clear_edges() - - # Add edges - for src_idx, dst_idx in category["skeleton"]: - skeleton.add_edge(node_names[src_idx - 1], node_names[dst_idx - 1]) - - skeleton_map[skeleton_id] = skeleton - - # Make videos from "images" - - # Remove images that aren't referenced in the annotations - img_refs = [annotation["image_id"] for annotation in dicts["annotations"]] - dicts["images"] = list(filter(lambda im: im["id"] in img_refs, dicts["images"])) - - # Key in JSON file should be "file_name", but sometimes it's "filename", - # so we have to check both. - img_filename_key = "file_name" - if img_filename_key not in dicts["images"][0].keys(): - img_filename_key = "filename" - - # First add the img_dir to each image filename - img_paths = [ - os.path.join(img_dir, image[img_filename_key]) for image in dicts["images"] - ] - - # See if there are any missing files - img_missing = [not os.path.exists(path) for path in img_paths] - - if sum(img_missing): - if use_missing_gui: - okay = MissingFilesDialog(img_paths, img_missing).exec_() - - if not okay: - return None - else: - raise FileNotFoundError( - f"Images for COCO dataset could not be found in {img_dir}." 
- ) - - # Update the image paths (with img_dir or user selected path) - for image, path in zip(dicts["images"], img_paths): - image[img_filename_key] = path - - # Create the video objects for the image files - image_video_map = dict() - - vid_id_video_map = dict() - for image in dicts["images"]: - image_id = image["id"] - image_filename = image[img_filename_key] - - # Sometimes images have a vid_id which links multiple images - # together as one video. If so, we'll use that as the video key. - # But if there isn't a vid_id, we'll treat each images as a - # distinct video and use the image id as the video id. - vid_id = image.get("vid_id", image_id) - - if vid_id not in vid_id_video_map: - kwargs = dict(filenames=[image_filename]) - for key in ("width", "height"): - if key in image: - kwargs[key] = image[key] - - video = Video.from_image_filenames(**kwargs) - vid_id_video_map[vid_id] = video - frame_idx = 0 - else: - video = vid_id_video_map[vid_id] - frame_idx = video.num_frames - video.backend.filenames.append(image_filename) - - image_video_map[image_id] = (video, frame_idx) - - # Make instances from "annotations" - lf_map = dict() - track_map = dict() - for annotation in dicts["annotations"]: - skeleton = skeleton_map[annotation["category_id"]] - image_id = annotation["image_id"] - video, frame_idx = image_video_map[image_id] - keypoints = np.array(annotation["keypoints"], dtype="int").reshape(-1, 3) - - track = None - if "track_id" in annotation: - track_id = annotation["track_id"] - if track_id not in track_map: - track_map[track_id] = Track(frame_idx, str(track_id)) - track = track_map[track_id] - - points = dict() - any_visible = False - for i in range(len(keypoints)): - node = skeleton.nodes[i] - x, y, flag = keypoints[i] - - if flag == 0: - # node not labeled for this instance - continue - - is_visible = flag == 2 - any_visible = any_visible or is_visible - points[node] = Point(x, y, is_visible) - - if points: - # If none of the points had 2 has the "visible" flag, we'll - # assume this incorrect and just mark all as visible. 
-                if not any_visible:
-                    for point in points.values():
-                        point.visible = True
-
-                inst = Instance(skeleton=skeleton, points=points, track=track)
-
-                if image_id not in lf_map:
-                    lf_map[image_id] = LabeledFrame(video, frame_idx)
-
-                lf_map[image_id].insert(0, inst)
-
-        return cls(labeled_frames=list(lf_map.values()))
-
-    @classmethod
-    def from_deepposekit(cls, filename: str, video_path: str, skeleton_path: str):
-        video = Video.from_filename(video_path)
-
-        skeleton_data = pd.read_csv(skeleton_path, header=0)
-        skeleton = Skeleton()
-        skeleton.add_nodes(skeleton_data["name"])
-        nodes = skeleton.nodes
-
-        for name, parent, swap in skeleton_data.itertuples(index=False, name=None):
-            if parent is not np.nan:
-                skeleton.add_edge(parent, name)
-
-        lfs = []
-        with h5.File(filename, "r") as f:
-            pose_matrix = f["pose"][:]
-
-            track_count, frame_count, node_count, _ = pose_matrix.shape
-
-            tracks = [Track(0, f"Track {i}") for i in range(track_count)]
-            for frame_idx in range(frame_count):
-                lf_instances = []
-                for track_idx in range(track_count):
-                    points_array = pose_matrix[track_idx, frame_idx, :, :]
-                    points = dict()
-                    for p in range(len(points_array)):
-                        x, y, score = points_array[p]
-                        points[nodes[p]] = Point(x, y)  # TODO: score
-
-                    inst = Instance(
-                        skeleton=skeleton, track=tracks[track_idx], points=points
-                    )
-                    lf_instances.append(inst)
-                lfs.append(
-                    LabeledFrame(video, frame_idx=frame_idx, instances=lf_instances)
-                )
-
-        return cls(labeled_frames=lfs)
-
     @classmethod
     def make_video_callback(cls, search_paths: Optional[List] = None) -> Callable:
         """
diff --git a/sleap/io/format/__init__.py b/sleap/io/format/__init__.py
new file mode 100644
index 000000000..0ecfacd33
--- /dev/null
+++ b/sleap/io/format/__init__.py
@@ -0,0 +1,101 @@
+from .coco import LabelsCocoAdaptor
+from .deeplabcut import LabelsDeepLabCutAdaptor
+from .deepposekit import LabelsDeepPoseKitAdaptor
+from .hdf5 import LabelsV1Adaptor
+from .labels_json import LabelsJsonAdaptor
+from .leap_matlab import LabelsLeapMatlabAdaptor
+
+from . import adaptor, dispatch, filehandle
+
+from typing import Text, Optional, Union
+
+default_labels_adaptors = [LabelsV1Adaptor, LabelsJsonAdaptor]
+
+all_labels_adaptors = {
+    "hdf5_v1": LabelsV1Adaptor,
+    "json": LabelsJsonAdaptor,
+    "leap": LabelsLeapMatlabAdaptor,
+    "deeplabcut": LabelsDeepLabCutAdaptor,
+    "deepposekit": LabelsDeepPoseKitAdaptor,
+    "coco": LabelsCocoAdaptor,
+}
+
+
+def read(
+    filename: Text,
+    for_object: Union[Text, object],
+    as_format: Optional[Text] = None,
+    *args,
+    **kwargs,
+):
+    """
+    Reads file using the appropriate file format adaptor.
+
+    Args:
+        filename: Full filename of the file to read.
+        for_object: The type of object we're trying to read; can be given as
+            string (e.g., "labels") or instance of the object.
+        as_format: Allows you to specify the format adaptor to use;
+            if not specified, then we'll try the default adaptors for this
+            object type.
+
+    Exceptions:
+        NotImplementedError if appropriate adaptor cannot be found.
+        TypeError if adaptor does not support reading
+            (shouldn't happen unless you specify `as_format` adaptor).
+        Any file-related exception thrown while trying to read.
+    """
+
+    disp = dispatch.Dispatch()
+
+    if as_format in all_labels_adaptors:
+        disp.register(all_labels_adaptors[as_format])
+        return disp.read(filename, *args, **kwargs)
+
+    if for_object == "labels" or hasattr(for_object, "labeled_frames"):
+        disp.register_list(default_labels_adaptors)
+        return disp.read(filename, *args, **kwargs)
+
+    raise NotImplementedError("No adaptors for this object type.")
+
+
+def write(
+    filename: str,
+    source_object: object,
+    as_format: Optional[Text] = None,
+    *args,
+    **kwargs,
+):
+    """
+    Writes file using the appropriate file format adaptor.
+
+    Args:
+        filename: Full filename of the file to write.
+            All directories should exist.
+        source_object: The object we want to write to a file.
+        as_format: Allows you to specify the format adaptor to use;
+            if not specified, then this will use the default adaptor for
+            the type of object.
+
+    Exceptions:
+        NotImplementedError if appropriate adaptor cannot be found.
+        TypeError if adaptor does not support writing
+            (shouldn't happen unless you specify `as_format` adaptor).
+        Any file-related exception thrown while trying to write.
+    """
+    disp = dispatch.Dispatch()
+
+    if as_format in all_labels_adaptors:
+        disp.register(all_labels_adaptors[as_format])
+        return disp.write(filename, source_object, *args, **kwargs)
+
+    elif as_format is not None:
+        raise KeyError(f"No adaptor for {as_format}.")
+
+    if hasattr(source_object, "labeled_frames"):
+        disp.register_list(default_labels_adaptors)
+        return disp.write(filename, source_object, *args, **kwargs)
+
+    raise NotImplementedError(
+        f"No adaptors for object type {type(source_object)} ({as_format})."
+    )
diff --git a/sleap/io/format/adaptor.py b/sleap/io/format/adaptor.py
new file mode 100644
index 000000000..9a47564bb
--- /dev/null
+++ b/sleap/io/format/adaptor.py
@@ -0,0 +1,75 @@
+import os
+from enum import Enum
+from typing import List
+
+import attr
+
+from sleap.io.format.filehandle import FileHandle
+
+
+class SleapObjectType(Enum):
+    misc = 0
+    labels = 1
+
+
+@attr.s(auto_attribs=True)
+class Adaptor(object):
+    """
+    Abstract base class which defines interface for file format adaptors.
+    """
+
+    @property
+    def handles(self) -> SleapObjectType:
+        """Returns the type of object that can be read/written."""
+        raise NotImplementedError
+
+    @property
+    def default_ext(self) -> str:
+        raise NotImplementedError
+
+    @property
+    def all_exts(self) -> List[str]:
+        raise NotImplementedError
+
+    @property
+    def name(self) -> str:
+        raise NotImplementedError
+
+    def can_read_file(self, file: FileHandle) -> bool:
+        """Returns whether this adaptor can read this file."""
+        raise NotImplementedError
+
+    def can_write_filename(self, filename: str) -> bool:
+        """Returns whether this adaptor can write format of this filename."""
+        raise NotImplementedError
+
+    def does_read(self) -> bool:
+        """Returns whether this adaptor supports reading."""
+        raise NotImplementedError
+
+    def does_write(self) -> bool:
+        """Returns whether this adaptor supports writing."""
+        raise NotImplementedError
+
+    def read(self, file: FileHandle) -> object:
+        """Reads the file and returns the appropriate deserialized object."""
+        raise NotImplementedError
+
+    def write(self, filename: str, source_object: object):
+        """Writes the object to a file."""
+        raise NotImplementedError
+
+    # Methods with default implementation
+
+    def does_match_ext(self, filename: str) -> bool:
+        """Returns whether the filename's extension matches this adaptor."""
+
+        # We don't match the ext against the result of os.path.splitext because
+        # we want to match extensions like ".json.zip".
+
+        return filename.endswith(tuple(self.all_exts))
+
+    @property
+    def formatted_ext_options(self):
+        """String for Qt file dialog extension options."""
+        return f"{self.name} ({' '.join(self.all_exts)})"
+ # [1]: http://cocodataset.org/#format-data + + # Clear any edges we already created using 0-indexing + skeleton.clear_edges() + + # Add edges + for src_idx, dst_idx in category["skeleton"]: + skeleton.add_edge(node_names[src_idx - 1], node_names[dst_idx - 1]) + + skeleton_map[skeleton_id] = skeleton + + # Make videos from "images" + + # Remove images that aren't referenced in the annotations + img_refs = [annotation["image_id"] for annotation in dicts["annotations"]] + dicts["images"] = list(filter(lambda im: im["id"] in img_refs, dicts["images"])) + + # Key in JSON file should be "file_name", but sometimes it's "filename", + # so we have to check both. + img_filename_key = "file_name" + if img_filename_key not in dicts["images"][0].keys(): + img_filename_key = "filename" + + # First add the img_dir to each image filename + img_paths = [ + os.path.join(img_dir, image[img_filename_key]) for image in dicts["images"] + ] + + # See if there are any missing files + img_missing = [not os.path.exists(path) for path in img_paths] + + if sum(img_missing): + if use_missing_gui: + okay = MissingFilesDialog(img_paths, img_missing).exec_() + + if not okay: + return None + else: + raise FileNotFoundError( + f"Images for COCO dataset could not be found in {img_dir}." + ) + + # Update the image paths (with img_dir or user selected path) + for image, path in zip(dicts["images"], img_paths): + image[img_filename_key] = path + + # Create the video objects for the image files + image_video_map = dict() + + vid_id_video_map = dict() + for image in dicts["images"]: + image_id = image["id"] + image_filename = image[img_filename_key] + + # Sometimes images have a vid_id which links multiple images + # together as one video. If so, we'll use that as the video key. + # But if there isn't a vid_id, we'll treat each images as a + # distinct video and use the image id as the video id. + vid_id = image.get("vid_id", image_id) + + if vid_id not in vid_id_video_map: + kwargs = dict(filenames=[image_filename]) + for key in ("width", "height"): + if key in image: + kwargs[key] = image[key] + + video = Video.from_image_filenames(**kwargs) + vid_id_video_map[vid_id] = video + frame_idx = 0 + else: + video = vid_id_video_map[vid_id] + frame_idx = video.num_frames + video.backend.filenames.append(image_filename) + + image_video_map[image_id] = (video, frame_idx) + + # Make instances from "annotations" + lf_map = dict() + track_map = dict() + for annotation in dicts["annotations"]: + skeleton = skeleton_map[annotation["category_id"]] + image_id = annotation["image_id"] + video, frame_idx = image_video_map[image_id] + keypoints = np.array(annotation["keypoints"], dtype="int").reshape(-1, 3) + + track = None + if "track_id" in annotation: + track_id = annotation["track_id"] + if track_id not in track_map: + track_map[track_id] = Track(frame_idx, str(track_id)) + track = track_map[track_id] + + points = dict() + any_visible = False + for i in range(len(keypoints)): + node = skeleton.nodes[i] + x, y, flag = keypoints[i] + + if flag == 0: + # node not labeled for this instance + continue + + is_visible = flag == 2 + any_visible = any_visible or is_visible + points[node] = Point(x, y, is_visible) + + if points: + # If none of the points had 2 has the "visible" flag, we'll + # assume this incorrect and just mark all as visible. 
+ if not any_visible: + for point in points.values(): + point.visible = True + + inst = Instance(skeleton=skeleton, points=points, track=track) + + if image_id not in lf_map: + lf_map[image_id] = LabeledFrame(video, frame_idx) + + lf_map[image_id].insert(0, inst) + + return Labels(labeled_frames=list(lf_map.values())) diff --git a/sleap/io/format/deeplabcut.py b/sleap/io/format/deeplabcut.py new file mode 100644 index 000000000..81563003f --- /dev/null +++ b/sleap/io/format/deeplabcut.py @@ -0,0 +1,122 @@ +import os + +import pandas as pd + +from sleap import Labels, Video, Skeleton +from sleap.instance import Instance, LabeledFrame, Point + +from .adaptor import Adaptor, SleapObjectType +from .filehandle import FileHandle + + +class LabelsDeepLabCutAdaptor(Adaptor): + @property + def handles(self): + return SleapObjectType.labels + + @property + def default_ext(self): + return "csv" + + @property + def all_exts(self): + return ["csv"] + + @property + def name(self): + return "DeepLabCut Dataset CSV" + + def can_read_file(self, file: FileHandle): + if not self.does_match_ext(file.filename): + return False + # TODO: add checks for valid deeplabcut csv + return True + + def can_write_filename(self, filename: str): + return False + + def does_read(self) -> bool: + return True + + def does_write(self) -> bool: + return False + + @classmethod + def read(cls, file: FileHandle, *args, **kwargs,) -> Labels: + filename = file.filename + + # At the moment we don't need anything from the config file, + # but the code to read it is here in case we do in the future. + + # # Try to find the config file by walking up file path starting at csv file looking for config.csv + # last_dir = None + # file_dir = os.path.dirname(filename) + # config_filename = "" + + # while file_dir != last_dir: + # last_dir = file_dir + # file_dir = os.path.dirname(file_dir) + # config_filename = os.path.join(file_dir, 'config.yaml') + # if os.path.exists(config_filename): + # break + + # # If we couldn't find a config file, give up + # if not os.path.exists(config_filename): return + + # with open(config_filename, 'r') as f: + # config = yaml.load(f, Loader=yaml.SafeLoader) + + # x1 = config['x1'] + # y1 = config['y1'] + # x2 = config['x2'] + # y2 = config['y2'] + + data = pd.read_csv(filename, header=[1, 2]) + + # Create the skeleton from the list of nodes in the csv file + # Note that DeepLabCut doesn't have edges, so these will have to be added by user later + node_names = [n[0] for n in list(data)[1::2]] + + skeleton = Skeleton() + skeleton.add_nodes(node_names) + + # Create an imagestore `Video` object from frame images. + # This may not be ideal for large projects, since we're reading in + # each image and then writing it out in a new directory. 
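+        # (The first, unnamed column of the csv gives the image path for each
+        # row; these images become the frames of the imgstore.)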
+
+        img_files = data.iloc[:, 0]  # get list of all images
+
+        # the image filenames in the csv may not match where the user has them
+        # so we'll change the directory to match where the user has the csv
+        def fix_img_path(img_dir, img_filename):
+            img_filename = os.path.basename(img_filename)
+            img_filename = os.path.join(img_dir, img_filename)
+            return img_filename
+
+        img_dir = os.path.dirname(filename)
+        img_files = list(map(lambda f: fix_img_path(img_dir, f), img_files))
+
+        # we'll put the new imgstore in the same directory as the current csv
+        imgstore_name = os.path.join(os.path.dirname(filename), "sleap_video")
+
+        # create the imgstore (or open if it already exists)
+        if os.path.exists(imgstore_name):
+            video = Video.from_filename(imgstore_name)
+        else:
+            video = Video.imgstore_from_filenames(img_files, imgstore_name)
+
+        labels = []
+
+        for i in range(len(data)):
+            # get points for each node
+            instance_points = dict()
+            for node in node_names:
+                x, y = data[(node, "x")][i], data[(node, "y")][i]
+                instance_points[node] = Point(x, y)
+            # create instance with points (we can assume there's only one instance per frame)
+            instance = Instance(skeleton=skeleton, points=instance_points)
+            # create labeledframe and add it to list
+            label = LabeledFrame(video=video, frame_idx=i, instances=[instance])
+            labels.append(label)
+
+        return Labels(labeled_frames=labels)
diff --git a/sleap/io/format/deepposekit.py b/sleap/io/format/deepposekit.py
new file mode 100644
index 000000000..c727dcee1
--- /dev/null
+++ b/sleap/io/format/deepposekit.py
@@ -0,0 +1,86 @@
+from .adaptor import Adaptor, SleapObjectType
+from .filehandle import FileHandle
+
+from sleap.instance import Instance, LabeledFrame, Point, Track
+
+from sleap import Labels, Video, Skeleton
+
+import numpy as np
+import pandas as pd
+
+
+class LabelsDeepPoseKitAdaptor(Adaptor):
+    @property
+    def handles(self):
+        return SleapObjectType.labels
+
+    @property
+    def default_ext(self):
+        return "h5"
+
+    @property
+    def all_exts(self):
+        return ["h5", "hdf5"]
+
+    @property
+    def name(self):
+        return "DeepPoseKit Dataset HDF5"
+
+    def can_read_file(self, file: FileHandle):
+        if not self.does_match_ext(file.filename):
+            return False
+        if not file.is_hdf5:
+            return False
+        if "pose" not in file.file:
+            return False
+        return True
+
+    def can_write_filename(self, filename: str):
+        return False
+
+    def does_read(self) -> bool:
+        return True
+
+    def does_write(self) -> bool:
+        return False
+
+    @classmethod
+    def read(
+        cls, file: FileHandle, video_path: str, skeleton_path: str, *args, **kwargs,
+    ) -> Labels:
+        f = file.file
+
+        video = Video.from_filename(video_path)
+        skeleton_data = pd.read_csv(skeleton_path, header=0)
+
+        skeleton = Skeleton()
+        skeleton.add_nodes(skeleton_data["name"])
+        nodes = skeleton.nodes
+
+        for name, parent, swap in skeleton_data.itertuples(index=False, name=None):
+            if parent is not np.nan:
+                skeleton.add_edge(parent, name)
+
+        lfs = []
+
+        pose_matrix = f["pose"][:]
+
+        track_count, frame_count, node_count, _ = pose_matrix.shape
+
+        tracks = [Track(0, f"Track {i}") for i in range(track_count)]
+        for frame_idx in range(frame_count):
+            lf_instances = []
+            for track_idx in range(track_count):
+                points_array = pose_matrix[track_idx, frame_idx, :, :]
+                points = dict()
+                for p in range(len(points_array)):
+                    x, y, score = points_array[p]
+                    points[nodes[p]] = Point(x, y)  # TODO: score
+
+                inst = Instance(
+                    skeleton=skeleton, track=tracks[track_idx], points=points
+                )
+                lf_instances.append(inst)
+
lfs.append(LabeledFrame(video, frame_idx=frame_idx, instances=lf_instances)) + + return Labels(labeled_frames=lfs) diff --git a/sleap/io/format/dispatch.py b/sleap/io/format/dispatch.py new file mode 100644 index 000000000..b32e85dc7 --- /dev/null +++ b/sleap/io/format/dispatch.py @@ -0,0 +1,96 @@ +import attr +from typing import List, Optional, Tuple, Union + +from sleap.io.format.adaptor import Adaptor, SleapObjectType +from sleap.io.format.filehandle import FileHandle + + +@attr.s(auto_attribs=True) +class Dispatch(object): + + _adaptors: List[Adaptor] = attr.ib(default=attr.Factory(list)) + + def register(self, adaptor: Union[Adaptor, type]): + """ + Registers the class which reads/writes specific file format. + """ + # If given a class, then instantiate it since we want the object + if type(adaptor) == type: + adaptor = adaptor() + + self._adaptors.append(adaptor) + + def register_list(self, adaptor_list: List[Union[Adaptor, type]]): + for adaptor in adaptor_list: + self.register(adaptor) + + def get_formatted_ext_options(self) -> List[str]: + """ + Returns the file extensions that can be used for specified type. + + This is used for determining which extensions to list in save dialog. + """ + return [adaptor.formatted_ext_options for adaptor in self._adaptors] + + def open(self, filename: str) -> FileHandle: + """Returns FileHandle for file.""" + return FileHandle(filename) + + def read(self, filename: str, *args, **kwargs) -> object: + """Reads file and returns the deserialized object.""" + + with self.open(filename) as file: + for adaptor in self._adaptors: + if adaptor.can_read_file(file): + return adaptor.read(file, *args, **kwargs) + + raise TypeError("No file format adaptor could read this file.") + + def read_safely(self, *args, **kwargs) -> Tuple[object, Optional[BaseException]]: + """Wrapper for reading file without throwing exception.""" + try: + return self.read(*args, **kwargs), None + except Exception as e: + return None, e + + def write(self, filename: str, source_object: object, *args, **kwargs): + """ + Writes an object to a file. + + Args: + filename: The full name (including path) of the file to write. + source_object: The object to write. 
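+            *args, **kwargs: Any additional arguments are passed through to
+                the write method of the matching adaptor.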
+ """ + + for adaptor in self._adaptors: + if adaptor.can_write_filename(filename): + return adaptor.write(filename, source_object, *args, **kwargs) + + raise TypeError("No file format adaptor could write this file.") + + def write_safely(self, *args, **kwargs) -> Optional[BaseException]: + """Wrapper for writing file without throwing exception.""" + try: + self.write(*args, **kwargs) + return None + except Exception as e: + return e + + @classmethod + def make_dispatcher(cls, object_type: SleapObjectType) -> "Dispatch": + dispatcher = cls() + if object_type == SleapObjectType.labels: + from .hdf5 import LabelsV1Adaptor + from .labels_json import LabelsJsonAdaptor + from .deeplabcut import LabelsDeepLabCutAdaptor + + dispatcher.register(LabelsV1Adaptor()) + dispatcher.register(LabelsJsonAdaptor()) + dispatcher.register(LabelsDeepLabCutAdaptor()) + + elif object_type == SleapObjectType.misc: + from .text import TextAdaptor + + dispatcher.register(TextAdaptor()) + + return dispatcher diff --git a/sleap/io/format/filehandle.py b/sleap/io/format/filehandle.py new file mode 100644 index 000000000..c8ed1cb88 --- /dev/null +++ b/sleap/io/format/filehandle.py @@ -0,0 +1,93 @@ +import os +from typing import Optional + +import attr +import h5py + +from sleap.util import json_loads + + +@attr.s(auto_attribs=True) +class FileHandle(object): + """Reference to a file; can hold loaded data so it needn't be read twice.""" + + filename: str + _is_hdf5: bool = False + _is_json: Optional[bool] = None + _is_open: bool = False + _file: object = None + _text: str = None + _json: object = None + + def __enter__(self): + self.open() + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + self.close() + + def open(self): + if not os.path.exists(self.filename): + raise FileNotFoundError(f"Could not find {self.filename}") + + if self._file is None: + try: + self._file = h5py.File(self.filename, "r") + self._is_hdf5 = True + except OSError as e: + # We get OSError when trying to read non-HDF5 file with h5py + pass + + if self._file is None: + self._file = open(self.filename, "r") + self._is_hdf5 = False + + def close(self): + if self._file is not None: + self._file.close() + + @property + def file(self): + self.open() + return self._file + + @property + def text(self): + if self._text is None: + self._text = self.file.read() + return self._text + + @property + def json(self): + if self._json is None: + self._json = json_loads(self.text) + return self._json + + @property + def is_json(self): + if self._is_json is None: + try: + self.json + self._is_json = True + except Exception as e: + self._is_json = False + return self._is_json + + @property + def is_hdf5(self): + self.open() + return self._is_hdf5 + + @property + def format_id(self): + if self.is_hdf5: + if "metadata" in self.file: + meta_group = self.file.require_group("metadata") + if "format_id" in meta_group.attrs: + return meta_group.attrs["format_id"] + + elif self.is_json: + if "format_id" in self.json: + return self.json["format_id"] + + return None diff --git a/sleap/io/format/genericjson.py b/sleap/io/format/genericjson.py new file mode 100644 index 000000000..c4330c1d8 --- /dev/null +++ b/sleap/io/format/genericjson.py @@ -0,0 +1,42 @@ +from .adaptor import Adaptor, SleapObjectType +from .filehandle import FileHandle + +from sleap.util import json_dumps + + +class GenericJsonAdaptor(Adaptor): + @property + def handles(self): + return SleapObjectType.misc + + @property + def default_ext(self): + return "json" + + @property + def 
all_exts(self): + return ["json", "txt"] + + @property + def name(self): + return "JSON file" + + def can_read_file(self, file: FileHandle): + if not self.does_match_ext(file.filename): + return False + return file.is_json + + def can_write_filename(self, filename: str) -> bool: + return True + + def does_read(self) -> bool: + return True + + def does_write(self) -> bool: + return True + + def read(self, file: FileHandle, *args, **kwargs): + return file.json + + def write(self, filename: str, source_object: dict): + json_dumps(source_object, filename) diff --git a/sleap/io/format/hdf5.py b/sleap/io/format/hdf5.py new file mode 100644 index 000000000..e41c32c82 --- /dev/null +++ b/sleap/io/format/hdf5.py @@ -0,0 +1,450 @@ +from sleap.io import format +from . import labels_json + +from sleap.instance import ( + PointArray, + PredictedPointArray, + Instance, + PredictedInstance, + LabeledFrame, + PredictedPoint, + Point, +) +from sleap.util import json_loads, json_dumps +from sleap import Labels, Video + +import h5py +import numpy as np +import os + +from typing import Optional + + +class LabelsV1Adaptor(format.adaptor.Adaptor): + FORMAT_ID = 1 + + @property + def handles(self): + return format.adaptor.SleapObjectType.labels + + @property + def default_ext(self): + return "h5" + + @property + def all_exts(self): + return ["h5", "hdf5"] + + @property + def name(self): + return "Labels HDF5" + + def can_read_file(self, file: format.filehandle.FileHandle): + if not self.does_match_ext(file.filename): + return False + if not file.is_hdf5: + return False + if file.format_id not in (None, self.FORMAT_ID): + return False + return True + + def can_write_filename(self, filename: str): + return self.does_match_ext(filename) + + def does_read(self) -> bool: + return True + + def does_write(self) -> bool: + return True + + @classmethod + def read( + cls, + file: format.filehandle.FileHandle, + video_callback=None, + match_to: Optional[Labels] = None, + *args, + **kwargs, + ): + f = file.file + + # Extract the Labels JSON metadata and create Labels object with just + # this metadata. + dicts = json_loads( + f.require_group("metadata").attrs["json"].tostring().decode() + ) + + for key in ("videos", "tracks", "suggestions"): + hdf5_key = f"{key}_json" + if hdf5_key in f: + items = [json_loads(item_json) for item_json in f[hdf5_key]] + dicts[key] = items + + # Video path "." means the video is saved in same file as labels, + # so replace these paths. + for video_item in dicts["videos"]: + if video_item["backend"]["filename"] == ".": + video_item["backend"]["filename"] = file.filename + + if hasattr(video_callback, "__iter__"): + # If the callback is an iterable, then we'll expect it to be a + # list of strings and build a non-gui callback with those as + # the search paths. + search_paths = [path for path in video_callback] + video_callback = Labels.make_video_callback(search_paths) + + # Use the callback if given to handle missing videos + if callable(video_callback): + video_callback(dicts["videos"]) + + labels = labels_json.LabelsJsonAdaptor.from_json_data(dicts, match_to=match_to) + + frames_dset = f["frames"][:] + instances_dset = f["instances"][:] + points_dset = f["points"][:] + pred_points_dset = f["pred_points"][:] + + # Rather than instantiate a bunch of Point\PredictedPoint objects, we will + # use inplace numpy recarrays. This will save a lot of time and memory + # when reading things in. 
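+        # (PointArray/PredictedPointArray are numpy recarray subclasses, so
+        # the slices taken from these buffers below are views, not copies.)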
+
+        points = PointArray(buf=points_dset, shape=len(points_dset))
+        pred_points = PredictedPointArray(
+            buf=pred_points_dset, shape=len(pred_points_dset)
+        )
+
+        # Extend the tracks list with a None track. We will signify this with a -1 in the
+        # data, which will map to the last element of tracks.
+        tracks = labels.tracks.copy()
+        tracks.extend([None])
+
+        # A dict to keep track of instances that have a from_predicted link. The key is
+        # the instance and the value is the index of the instance it was predicted from.
+        from_predicted_lookup = {}
+
+        # Create the instances
+        instances = []
+        for i in instances_dset:
+            track = tracks[i["track"]]
+            skeleton = labels.skeletons[i["skeleton"]]
+
+            if i["instance_type"] == 0:  # Instance
+                instance = Instance(
+                    skeleton=skeleton,
+                    track=track,
+                    points=points[i["point_id_start"] : i["point_id_end"]],
+                )
+            else:  # PredictedInstance
+                instance = PredictedInstance(
+                    skeleton=skeleton,
+                    track=track,
+                    points=pred_points[i["point_id_start"] : i["point_id_end"]],
+                    score=i["score"],
+                )
+            instances.append(instance)
+
+            if i["from_predicted"] != -1:
+                from_predicted_lookup[instance] = i["from_predicted"]
+
+        # Make a second pass to add any from_predicted links
+        for instance, from_predicted_idx in from_predicted_lookup.items():
+            instance.from_predicted = instances[from_predicted_idx]
+
+        # Create the labeled frames
+        frames = [
+            LabeledFrame(
+                video=labels.videos[frame["video"]],
+                frame_idx=frame["frame_idx"],
+                instances=instances[
+                    frame["instance_id_start"] : frame["instance_id_end"]
+                ],
+            )
+            for i, frame in enumerate(frames_dset)
+        ]
+
+        labels.labeled_frames = frames
+
+        # Do the stuff that should happen after we have labeled frames
+        labels._build_lookup_caches()
+
+        return labels
+
+    @classmethod
+    def write(
+        cls,
+        filename: str,
+        source_object: object,
+        append: bool = False,
+        save_frame_data: bool = False,
+        frame_data_format: str = "png",
+    ):
+
+        labels = source_object
+
+        # Delete the file if it exists; we want to start from scratch, since
+        # h5py truncates the file in a way that doesn't seem to actually
+        # delete data from the file. Skip this if we are appending, of course.
+        if os.path.exists(filename) and not append:
+            os.unlink(filename)
+
+        # Serialize all the meta-data to JSON.
+        d = labels.to_dict(skip_labels=True)
+
+        if save_frame_data:
+            new_videos = labels.save_frame_data_hdf5(filename, frame_data_format)
+
+            # Replace path to video file with "." (which indicates that the
+            # video is in the same file as the HDF5 labels dataset).
+            # Otherwise, the video paths will break if the HDF5 labels
+            # dataset file is moved.
+            for vid in new_videos:
+                vid.backend.filename = "."
+
+            d["videos"] = Video.cattr().unstructure(new_videos)
+
+        with h5py.File(filename, "a") as f:
+
+            # Add all the JSON metadata
+            meta_group = f.require_group("metadata")
+
+            meta_group.attrs["format_id"] = cls.FORMAT_ID
+
+            # If we are appending and JSON metadata already exists
+            if append and "json" in meta_group.attrs:
+
+                # Read the existing JSON so we can append the new data to its lists
+                old_labels = labels_json.LabelsJsonAdaptor.from_json_data(
+                    meta_group.attrs["json"].tostring().decode()
+                )
+
+                # A function to join two lists, only adding the new non-dupe entries
+                # from the right-hand list.
+                def append_unique(old, new):
+                    unique = []
+                    for x in new:
+                        try:
+                            matches = [y.matches(x) for y in old]
+                        except AttributeError:
+                            matches = [x == y for y in old]
+
+                        # If there were no matches, this is a unique object.
+                        if sum(matches) == 0:
+                            unique.append(x)
+                        else:
+                            # If we have an object that matches, replace the instance with
+                            # the one from the new list. This will make sure objects
+                            # on the Instances are the same as those in the Labels lists.
+                            for i, match in enumerate(matches):
+                                if match:
+                                    old[i] = x
+
+                    return old + unique
+
+                # Append the lists
+                labels.tracks = append_unique(old_labels.tracks, labels.tracks)
+                labels.skeletons = append_unique(old_labels.skeletons, labels.skeletons)
+                labels.videos = append_unique(old_labels.videos, labels.videos)
+                labels.nodes = append_unique(old_labels.nodes, labels.nodes)
+
+                # FIXME: Do something for suggestions and negative_anchors
+
+                # Get the dict for JSON and save it over the old data
+                d = labels.to_dict(skip_labels=True)
+
+            if not append:
+                for key in ("videos", "tracks", "suggestions"):
+                    # Convert for saving in hdf5 dataset
+                    data = [np.string_(json_dumps(item)) for item in d[key]]
+
+                    hdf5_key = f"{key}_json"
+
+                    # Save in its own dataset (e.g., videos_json)
+                    f.create_dataset(hdf5_key, data=data, maxshape=(None,))
+
+                    # Clear from dict since we don't want to save this in attribute
+                    d[key] = []
+
+            # Output the dict to JSON
+            meta_group.attrs["json"] = np.string_(json_dumps(d))
+
+            # FIXME: We can probably construct these from attrs fields
+            # We will store Instances and PredictedInstances in the same
+            # table. instance_type=0 for Instance and instance_type=1 for
+            # PredictedInstance; score will be ignored for Instances.
+            instance_dtype = np.dtype(
+                [
+                    ("instance_id", "i8"),
+                    ("instance_type", "u1"),
+                    ("frame_id", "u8"),
+                    ("skeleton", "u4"),
+                    ("track", "i4"),
+                    ("from_predicted", "i8"),
+                    ("score", "f4"),
+                    ("point_id_start", "u8"),
+                    ("point_id_end", "u8"),
+                ]
+            )
+            frame_dtype = np.dtype(
+                [
+                    ("frame_id", "u8"),
+                    ("video", "u4"),
+                    ("frame_idx", "u8"),
+                    ("instance_id_start", "u8"),
+                    ("instance_id_end", "u8"),
+                ]
+            )
+
+            num_instances = len(labels.all_instances)
+            max_skeleton_size = max([len(s.nodes) for s in labels.skeletons], default=0)
+
+            # Initialize data arrays for serialization
+            points = np.zeros(num_instances * max_skeleton_size, dtype=Point.dtype)
+            pred_points = np.zeros(
+                num_instances * max_skeleton_size, dtype=PredictedPoint.dtype
+            )
+            instances = np.zeros(num_instances, dtype=instance_dtype)
+            frames = np.zeros(len(labels), dtype=frame_dtype)
+
+            # Precompute some structures to make serialization faster
+            skeleton_to_idx = {
+                skeleton: labels.skeletons.index(skeleton)
+                for skeleton in labels.skeletons
+            }
+            track_to_idx = {
+                track: labels.tracks.index(track) for track in labels.tracks
+            }
+            track_to_idx[None] = -1
+            video_to_idx = {
+                video: labels.videos.index(video) for video in labels.videos
+            }
+            instance_type_to_idx = {Instance: 0, PredictedInstance: 1}
+
+            # Each instance we create will have an index in the dataset; keep track of
+            # these so we can quickly add from_predicted links on a second pass.
+            instance_to_idx = {}
+            instances_with_from_predicted = []
+            instances_from_predicted = []
+
+            # If we are appending, we need to look inside to see what frame, instance,
+            # and point ids we need to start from. This gives us offsets to use.
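+            # e.g., if the "instances" dataset already holds ids 0..N-1, the
+            # instances appended now will be numbered starting at N.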
+ if append and "points" in f: + point_id_offset = f["points"].shape[0] + pred_point_id_offset = f["pred_points"].shape[0] + instance_id_offset = f["instances"][-1]["instance_id"] + 1 + frame_id_offset = int(f["frames"][-1]["frame_id"]) + 1 + else: + point_id_offset = 0 + pred_point_id_offset = 0 + instance_id_offset = 0 + frame_id_offset = 0 + + point_id = 0 + pred_point_id = 0 + instance_id = 0 + + for frame_id, label in enumerate(labels): + frames[frame_id] = ( + frame_id + frame_id_offset, + video_to_idx[label.video], + label.frame_idx, + instance_id + instance_id_offset, + instance_id + instance_id_offset + len(label.instances), + ) + for instance in label.instances: + + # Add this instance to our lookup structure we will need for from_predicted + # links + instance_to_idx[instance] = instance_id + + parray = instance.get_points_array(copy=False, full=True) + instance_type = type(instance) + + # Check whether we are working with a PredictedInstance or an Instance. + if instance_type is PredictedInstance: + score = instance.score + pid = pred_point_id + pred_point_id_offset + else: + score = np.nan + pid = point_id + point_id_offset + + # Keep track of any from_predicted instance links, we will insert the + # correct instance_id in the dataset after we are done. + if instance.from_predicted: + instances_with_from_predicted.append(instance_id) + instances_from_predicted.append(instance.from_predicted) + + # Copy all the data + instances[instance_id] = ( + instance_id + instance_id_offset, + instance_type_to_idx[instance_type], + frame_id, + skeleton_to_idx[instance.skeleton], + track_to_idx[instance.track], + -1, + score, + pid, + pid + len(parray), + ) + + # If these are predicted points, copy them to the predicted point array + # otherwise, use the normal point array + if type(parray) is PredictedPointArray: + pred_points[ + pred_point_id : pred_point_id + len(parray) + ] = parray + pred_point_id = pred_point_id + len(parray) + else: + points[point_id : point_id + len(parray)] = parray + point_id = point_id + len(parray) + + instance_id = instance_id + 1 + + # Add from_predicted links + for instance_id, from_predicted in zip( + instances_with_from_predicted, instances_from_predicted + ): + try: + instances[instance_id]["from_predicted"] = instance_to_idx[ + from_predicted + ] + except KeyError: + # If we haven't encountered the from_predicted instance yet then don't save the link. + # It’s possible for a user to create a regular instance from a predicted instance and then + # delete all predicted instances from the file, but in this case I don’t think there’s any reason + # to remember which predicted instance the regular instance came from. + pass + + # We pre-allocated our points array with max possible size considering the max + # skeleton size, drop any unused points. 
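+            # (Note: the datasets are created with maxshape=(None,) below,
+            # which is what makes the resize() calls in the append branch
+            # possible.)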
+            points = points[0:point_id]
+            pred_points = pred_points[0:pred_point_id]
+
+            # Append to the existing datasets, or create them if they don't exist yet
+            if append and "points" in f:
+                f["points"].resize((f["points"].shape[0] + points.shape[0]), axis=0)
+                f["points"][-points.shape[0] :] = points
+                f["pred_points"].resize(
+                    (f["pred_points"].shape[0] + pred_points.shape[0]), axis=0
+                )
+                f["pred_points"][-pred_points.shape[0] :] = pred_points
+                f["instances"].resize(
+                    (f["instances"].shape[0] + instances.shape[0]), axis=0
+                )
+                f["instances"][-instances.shape[0] :] = instances
+                f["frames"].resize((f["frames"].shape[0] + frames.shape[0]), axis=0)
+                f["frames"][-frames.shape[0] :] = frames
+            else:
+                f.create_dataset(
+                    "points", data=points, maxshape=(None,), dtype=Point.dtype
+                )
+                f.create_dataset(
+                    "pred_points",
+                    data=pred_points,
+                    maxshape=(None,),
+                    dtype=PredictedPoint.dtype,
+                )
+                f.create_dataset(
+                    "instances", data=instances, maxshape=(None,), dtype=instance_dtype
+                )
+                f.create_dataset(
+                    "frames", data=frames, maxshape=(None,), dtype=frame_dtype
+                )
diff --git a/sleap/io/format/labels_json.py b/sleap/io/format/labels_json.py
new file mode 100644
index 000000000..a2e09b539
--- /dev/null
+++ b/sleap/io/format/labels_json.py
@@ -0,0 +1,468 @@
+import atexit
+import os
+import re
+import shutil
+import tempfile
+import zipfile
+from typing import Optional, Union, Dict, List, Callable
+
+import cattr
+
+from .adaptor import Adaptor, SleapObjectType
+from .filehandle import FileHandle
+
+from sleap import Labels, Video
+from sleap.gui.suggestions import SuggestionFrame
+from sleap.instance import (
+    LabeledFrame,
+    Track,
+    make_instance_cattr,
+)
+from sleap.io.legacy import load_labels_json_old
+from sleap.skeleton import Node, Skeleton
+from sleap.util import json_loads, json_dumps, weak_filename_match
+
+
+class LabelsJsonAdaptor(Adaptor):
+    FORMAT_ID = 1
+
+    @property
+    def handles(self):
+        return SleapObjectType.labels
+
+    @property
+    def default_ext(self):
+        return "json"
+
+    @property
+    def all_exts(self):
+        return ["json", "json.zip"]
+
+    @property
+    def name(self):
+        return "Labels JSON"
+
+    def can_read_file(self, file: FileHandle):
+        if not self.does_match_ext(file.filename):
+            print(f"{file.filename} doesn't match ext for json or json.zip")
+            return False
+
+        if file.filename.endswith(".zip"):
+            # We can't check inside the zip so assume it's correct
+            return True
+
+        if not file.is_json:
+            return False
+        if file.format_id not in (None, self.FORMAT_ID):
+            return False
+        return True
+
+    def can_write_filename(self, filename: str):
+        return self.does_match_ext(filename)
+
+    def does_read(self) -> bool:
+        return True
+
+    def does_write(self) -> bool:
+        return True
+
+    @classmethod
+    def read(
+        cls,
+        file: FileHandle,
+        video_callback: Optional[Callable] = None,
+        match_to: Optional[Labels] = None,
+        *args,
+        **kwargs,
+    ) -> Labels:
+        """
+        Deserialize JSON file as new :class:`Labels` instance.
+
+        Args:
+            file: The :class:`FileHandle` for the JSON file.
+            video_callback: A callback function which can modify
+                video paths before we try to create the corresponding
+                :class:`Video` objects. Usually you'll want to pass
+                a callback created by :meth:`make_video_callback`
+                or :meth:`make_gui_video_callback`.
+                Alternately, if you pass a list of strings we'll construct a
+                non-gui callback with those strings as the search paths.
+            match_to: If given, we'll replace particular objects in the
+                data dictionary with *matching* objects in the match_to
+                :class:`Labels` object.
This ensures that the newly + instantiated :class:`Labels` can be merged without + duplicate matching objects (e.g., :class:`Video` objects ). + Returns: + A new :class:`Labels` object. + """ + + tmp_dir = None + filename = file.filename + + # Check if the file is a zipfile for not. + if zipfile.is_zipfile(filename): + + # Make a tmpdir, located in the directory that the file exists, to unzip + # its contents. + tmp_dir = os.path.join( + os.path.dirname(filename), + f"tmp_{os.getpid()}_{os.path.basename(filename)}", + ) + if os.path.exists(tmp_dir): + shutil.rmtree(tmp_dir, ignore_errors=True) + try: + os.mkdir(tmp_dir) + except FileExistsError: + pass + + # tmp_dir = tempfile.mkdtemp(dir=os.path.dirname(filename)) + + try: + + # Register a cleanup routine that deletes the tmpdir on program exit + # if something goes wrong. The True is for ignore_errors + atexit.register(shutil.rmtree, tmp_dir, True) + + # Uncompress the data into the directory + shutil.unpack_archive(filename, extract_dir=tmp_dir) + + # We can now open the JSON file, save the zip file and + # replace file with the first JSON file we find in the archive. + json_files = [ + os.path.join(tmp_dir, file) + for file in os.listdir(tmp_dir) + if file.endswith(".json") + ] + + if len(json_files) == 0: + raise ValueError( + f"No JSON file found inside {filename}. Are you sure this is a valid sLEAP dataset." + ) + + filename = json_files[0] + + except Exception as ex: + # If we had problems, delete the temp directory and reraise the exception. + shutil.rmtree(tmp_dir, ignore_errors=True) + raise + + # Open and parse the JSON in filename + with open(filename, "r") as file: + + # FIXME: Peek into the json to see if there is version string. + # We do this to tell apart old JSON data from leap_dev vs the + # newer format for sLEAP. + json_str = file.read() + dicts = json_loads(json_str) + + # If we have a version number, then it is new sLEAP format + if "version" in dicts: + + # Cache the working directory. + cwd = os.getcwd() + # Replace local video paths (for imagestore) + if tmp_dir: + for vid in dicts["videos"]: + vid["backend"]["filename"] = os.path.join( + tmp_dir, vid["backend"]["filename"] + ) + + if hasattr(video_callback, "__iter__"): + # If the callback is an iterable, then we'll expect it to be a + # list of strings and build a non-gui callback with those as + # the search paths. + search_paths = [path for path in video_callback] + video_callback = Labels.make_video_callback(search_paths) + + # Use the callback if given to handle missing videos + if callable(video_callback): + abort = video_callback(dicts["videos"]) + if abort: + raise FileNotFoundError + + # Try to load the labels filename. + try: + labels = cls.from_json_data(dicts, match_to=match_to) + + except FileNotFoundError: + + # FIXME: We are going to the labels JSON that has references to + # video files. Lets change directory to the dirname of the json file + # so that relative paths will be from this directory. Maybe + # it is better to feed the dataset dirname all the way down to + # the Video object. This seems like less coupling between classes + # though. + if os.path.dirname(filename) != "": + os.chdir(os.path.dirname(filename)) + + # Try again + labels = cls.from_json_data(dicts, match_to=match_to) + + except Exception as ex: + # Ok, we give up, where the hell are these videos! + raise # Re-raise. + finally: + os.chdir(cwd) # Make sure to change back if we have problems. 
+ + return labels + + else: + frames = load_labels_json_old(data_path=filename, parsed_json=dicts) + return Labels(frames) + + @classmethod + def write( + cls, + filename: str, + source_object: str, + compress: Optional[bool] = None, + save_frame_data: bool = False, + frame_data_format: str = "png", + ): + """ + Save a Labels instance to a JSON format. + + Args: + filename: The filename to save the data to. + source_object: The labels dataset to save. + compress: Whether the data be zip compressed or not? If True, + the JSON will be compressed using Python's shutil.make_archive + command into a PKZIP zip file. If compress is True then + filename will have a .zip appended to it. + save_frame_data: Whether to save the image data for each frame. + For each video in the dataset, all frames that have labels + will be stored as an imgstore dataset. + If save_frame_data is True then compress will be forced to True + since the archive must contain both the JSON data and image + data stored in ImgStores. + frame_data_format: If save_frame_data is True, then this argument + is used to set the data format to use when writing frame + data to ImgStore objects. Supported formats should be: + + * 'pgm', + * 'bmp', + * 'ppm', + * 'tif', + * 'png', + * 'jpg', + * 'npy', + * 'mjpeg/avi', + * 'h264/mkv', + * 'avc1/mp4' + + Note: 'h264/mkv' and 'avc1/mp4' require separate installation + of these codecs on your system. They are excluded from SLEAP + because of their GPL license. + + Returns: + None + """ + + labels = source_object + + if compress is None: + compress = filename.endswith(".zip") + + # Lets make a temporary directory to store the image frame data or pre-compressed json + # in case we need it. + with tempfile.TemporaryDirectory() as tmp_dir: + + # If we are saving frame data along with the datasets. We will replace videos with + # new video object that represent video data from just the labeled frames. + if save_frame_data: + + # Create a set of new Video objects with imgstore backends. One for each + # of the videos. We will only include the labeled frames though. We will + # then replace each video with this new video + new_videos = labels.save_frame_data_imgstore( + output_dir=tmp_dir, format=frame_data_format + ) + + # Make video paths relative + for vid in new_videos: + tmp_path = vid.filename + # Get the parent dir of the YAML file. + # Use "/" since this works on Windows and posix + img_store_dir = ( + os.path.basename(os.path.split(tmp_path)[0]) + + "/" + + os.path.basename(tmp_path) + ) + # Change to relative path + vid.backend.filename = img_store_dir + + # Convert to a dict, not JSON yet, because we need to patch up the videos + d = labels.to_dict() + d["videos"] = Video.cattr().unstructure(new_videos) + + else: + d = labels.to_dict() + + # Set file format version + d["format_id"] = cls.FORMAT_ID + + if compress or save_frame_data: + + # Ensure that filename ends with .json + # shutil will append .zip + filename = re.sub("(\.json)?(\.zip)?$", ".json", filename) + + # Write the json to the tmp directory, we will zip it up with the frame data. 
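+            # The resulting archive will contain this json plus, if
+            # save_frame_data was set, one imgstore directory per video.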
+ full_out_filename = os.path.join(tmp_dir, os.path.basename(filename)) + json_dumps(d, full_out_filename) + + # Create the archive + shutil.make_archive(base_name=filename, root_dir=tmp_dir, format="zip") + + # If the user doesn't want to compress, then just write the json to the filename + else: + json_dumps(d, filename) + + @classmethod + def from_json_data( + cls, data: Union[str, dict], match_to: Optional["Labels"] = None + ) -> "Labels": + """ + Create instance of class from data in dictionary. + + Method is used by other methods that load from JSON. + + Args: + data: Dictionary, deserialized from JSON. + match_to: If given, we'll replace particular objects in the + data dictionary with *matching* objects in the match_to + :class:`Labels` object. This ensures that the newly + instantiated :class:`Labels` can be merged without + duplicate matching objects (e.g., :class:`Video` objects ). + Returns: + A new :class:`Labels` object. + """ + + # Parse the json string if needed. + if type(data) is str: + dicts = json_loads(data) + else: + dicts = data + + dicts["tracks"] = dicts.get( + "tracks", [] + ) # don't break if json doesn't include tracks + + # First, deserialize the skeletons, videos, and nodes lists. + # The labels reference these so we will need them while deserializing. + nodes = cattr.structure(dicts["nodes"], List[Node]) + + idx_to_node = {i: nodes[i] for i in range(len(nodes))} + skeletons = Skeleton.make_cattr(idx_to_node).structure( + dicts["skeletons"], List[Skeleton] + ) + videos = Video.cattr().structure(dicts["videos"], List[Video]) + + try: + # First try unstructuring tuple (newer format) + track_cattr = cattr.Converter( + unstruct_strat=cattr.UnstructureStrategy.AS_TUPLE + ) + tracks = track_cattr.structure(dicts["tracks"], List[Track]) + except: + # Then try unstructuring dict (older format) + try: + tracks = cattr.structure(dicts["tracks"], List[Track]) + except: + raise ValueError("Unable to load tracks as tuple or dict!") + + # if we're given a Labels object to match, use its objects when they match + if match_to is not None: + for idx, sk in enumerate(skeletons): + for old_sk in match_to.skeletons: + if sk.matches(old_sk): + # use nodes from matched skeleton + for (node, match_node) in zip(sk.nodes, old_sk.nodes): + node_idx = nodes.index(node) + nodes[node_idx] = match_node + # use skeleton from match + skeletons[idx] = old_sk + break + for idx, vid in enumerate(videos): + for old_vid in match_to.videos: + # compare last three parts of path + if vid.filename == old_vid.filename or weak_filename_match( + vid.filename, old_vid.filename + ): + # use video from match + videos[idx] = old_vid + break + + suggestions = [] + if "suggestions" in dicts: + suggestions_cattr = cattr.Converter() + suggestions_cattr.register_structure_hook( + Video, lambda x, type: videos[int(x)] + ) + try: + suggestions = suggestions_cattr.structure( + dicts["suggestions"], List[SuggestionFrame] + ) + except Exception as e: + print("Error while loading suggestions (1)") + print(e) + + try: + # Convert old suggestion format to new format. 
+ # Old format: {video: list of frame indices} + # New format: [SuggestionFrames] + old_suggestions = suggestions_cattr.structure( + dicts["suggestions"], Dict[Video, List] + ) + for video in old_suggestions.keys(): + suggestions.extend( + [ + SuggestionFrame(video, idx) + for idx in old_suggestions[video] + ] + ) + except Exception as e: + print("Error while loading suggestions (2)") + print(e) + pass + + if "negative_anchors" in dicts: + negative_anchors_cattr = cattr.Converter() + negative_anchors_cattr.register_structure_hook( + Video, lambda x, type: videos[int(x)] + ) + negative_anchors = negative_anchors_cattr.structure( + dicts["negative_anchors"], Dict[Video, List] + ) + else: + negative_anchors = dict() + + # If there is actual labels data, get it. + if "labels" in dicts: + label_cattr = make_instance_cattr() + label_cattr.register_structure_hook( + Skeleton, lambda x, type: skeletons[int(x)] + ) + label_cattr.register_structure_hook(Video, lambda x, type: videos[int(x)]) + label_cattr.register_structure_hook( + Node, lambda x, type: x if isinstance(x, Node) else nodes[int(x)] + ) + label_cattr.register_structure_hook( + Track, lambda x, type: None if x is None else tracks[int(x)] + ) + + labels = label_cattr.structure(dicts["labels"], List[LabeledFrame]) + else: + labels = [] + + return Labels( + labeled_frames=labels, + videos=videos, + skeletons=skeletons, + nodes=nodes, + suggestions=suggestions, + negative_anchors=negative_anchors, + tracks=tracks, + ) diff --git a/sleap/io/format/leap_matlab.py b/sleap/io/format/leap_matlab.py new file mode 100644 index 000000000..ece6e9a17 --- /dev/null +++ b/sleap/io/format/leap_matlab.py @@ -0,0 +1,133 @@ +import os + +import scipy.io as sio + +from sleap import Labels, Video, Skeleton +from sleap.gui.missingfiles import MissingFilesDialog +from sleap.instance import ( + Instance, + LabeledFrame, + Point, +) +from .adaptor import Adaptor, SleapObjectType +from .filehandle import FileHandle + + +class LabelsLeapMatlabAdaptor(Adaptor): + @property + def handles(self): + return SleapObjectType.labels + + @property + def default_ext(self): + return "mat" + + @property + def all_exts(self): + return ["mat"] + + @property + def name(self): + return "LEAP Matlab dataset" + + def can_read_file(self, file: FileHandle): + if not self.does_match_ext(file.filename): + return False + # if "boxPath" not in file.file: + # return False + return True + + def can_write_filename(self, filename: str): + return self.does_match_ext(filename) + + def does_read(self) -> bool: + return True + + def does_write(self) -> bool: + return False + + @classmethod + def read( + cls, file: FileHandle, gui: bool = True, *args, **kwargs, + ): + filename = file.filename + + mat_contents = sio.loadmat(filename) + + box_path = cls._unwrap_mat_scalar(mat_contents["boxPath"]) + + # If the video file isn't found, try in the same dir as the mat file + if not os.path.exists(box_path): + file_dir = os.path.dirname(filename) + box_path_name = box_path.split("\\")[-1] # assume windows path + box_path = os.path.join(file_dir, box_path_name) + + if not os.path.exists(box_path): + if gui: + video_paths = [box_path] + missing = [True] + okay = MissingFilesDialog(video_paths, missing).exec_() + + if not okay or missing[0]: + return + + box_path = video_paths[0] + else: + # Ignore missing videos if not loading from gui + box_path = "" + + if os.path.exists(box_path): + vid = Video.from_hdf5( + dataset="box", filename=box_path, input_format="channels_first" + ) + else: + vid = None + + 
nodes_ = mat_contents["skeleton"]["nodes"] + edges_ = mat_contents["skeleton"]["edges"] + points_ = mat_contents["positions"] + + edges_ = edges_ - 1 # convert matlab 1-indexing to python 0-indexing + + nodes = cls._unwrap_mat_array(nodes_) + edges = cls._unwrap_mat_array(edges_) + + nodes = list(map(str, nodes)) # convert np._str to str + + sk = Skeleton(name=filename) + sk.add_nodes(nodes) + for edge in edges: + sk.add_edge(source=nodes[edge[0]], destination=nodes[edge[1]]) + + labeled_frames = [] + node_count, _, frame_count = points_.shape + + for i in range(frame_count): + new_inst = Instance(skeleton=sk) + for node_idx, node in enumerate(nodes): + x = points_[node_idx][0][i] + y = points_[node_idx][1][i] + new_inst[node] = Point(x, y) + if len(new_inst.points): + new_frame = LabeledFrame(video=vid, frame_idx=i) + new_frame.instances = (new_inst,) + labeled_frames.append(new_frame) + + labels = Labels(labeled_frames=labeled_frames, videos=[vid], skeletons=[sk]) + + return labels + + @classmethod + def _unwrap_mat_scalar(cls, a): + """Extract single value from nested MATLAB file data.""" + if a.shape == (1,): + return cls._unwrap_mat_scalar(a[0]) + else: + return a + + @classmethod + def _unwrap_mat_array(cls, a): + """Extract list of values from nested MATLAB file data.""" + b = a[0][0] + c = [cls._unwrap_mat_scalar(x) for x in b] + return c diff --git a/sleap/io/format/text.py b/sleap/io/format/text.py new file mode 100644 index 000000000..f72ab4773 --- /dev/null +++ b/sleap/io/format/text.py @@ -0,0 +1,39 @@ +from .adaptor import Adaptor, SleapObjectType +from .filehandle import FileHandle + + +class TextAdaptor(Adaptor): + @property + def handles(self): + return SleapObjectType.misc + + @property + def default_ext(self): + return "txt" + + @property + def all_exts(self): + return ["txt", "log"] + + @property + def name(self): + return "Text file" + + def can_read_file(self, file: FileHandle): + return True # FIXME + + def can_write_filename(self, filename: str) -> bool: + return True + + def does_read(self) -> bool: + return True + + def does_write(self) -> bool: + return True + + def read(self, file: FileHandle, *args, **kwargs): + return file.text + + def write(self, filename: str, source_object: str): + with open(filename, "w") as f: + f.write(source_object) diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 32d7e3d48..3e123267c 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -47,6 +47,7 @@ class Predictor: * "confmap": Instance of `peak_finding.ConfmapPeakFinder` * "paf": Instance of `paf_grouping.PAFGrouper` * "tracking": Instance of `tracking.Tracker` + * "previous_predictions": `predicted.PredictedInstancePredictor` Note: the pipeline will be determined by which policies are given. """ @@ -77,15 +78,27 @@ def predict( if self.has_grayscale_models: video_kwargs["grayscale"] = True + is_dummy_video = False + if "tracking" in self.policies and "previous_predictions" in self.policies: + if not self.policies["tracking"].uses_image: + # We're just running the tracker for previous predictions + # and this tracker doesn't use the images, so we'll load + # "dummy" images. 
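+                # (The tracker only needs the previously predicted instances,
+                # not pixel data, so zero-filled frames are sufficient; see
+                # VideoLoader.load_frames in sleap/nn/utils.py.)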
+ is_dummy_video = True + video_ds = utils.VideoLoader( - filename=video_filename, frame_inds=frames, **video_kwargs + filename=video_filename, + frame_inds=frames, + dummy=is_dummy_video, + **video_kwargs, ) predicted_frames = [] for chunk_ind, frame_inds, imgs in video_ds: + predicted_instances_chunk = self.predict_chunk( - imgs, chunk_ind, video_ds.chunk_size + imgs, chunk_ind, video_ds.chunk_size, frame_inds=frame_inds ) sample_inds = np.arange(len(imgs)) @@ -100,9 +113,12 @@ def predict( return predicted_frames - def predict_chunk(self, img_chunk, chunk_ind, chunk_size): + def predict_chunk(self, img_chunk, chunk_ind, chunk_size, frame_inds=None): """Runs the inference components of pipeline for a chunk.""" + if "previous_predictions" in self.policies: + return self.policies["previous_predictions"].get_chunk(frame_inds) + if "centroid" in self.policies: # Detect centroids and pull out region proposals. centroid_predictor = self.policies["centroid"] @@ -306,7 +322,6 @@ def frame_list(frame_str: str): action="append", help="Path to saved model (confmaps, pafs, ...) JSON. " "Multiple models can be specified, each preceded by --model.", - required=True, ) parser.add_argument( @@ -350,10 +365,14 @@ def frame_list(frame_str: str): @classmethod def cli_args_to_policies(cls, args): policy_args = util.make_scoped_dictionary(vars(args), exclude_nones=True) - return cls.from_paths_and_policy_args(args.models, policy_args) + return cls.from_paths_and_policy_args( + model_paths=args.models, policy_args=policy_args, args=args, + ) @classmethod - def from_paths_and_policy_args(cls, model_paths: List[str], policy_args: dict): + def from_paths_and_policy_args( + cls, model_paths: List[str], policy_args: dict, args: dict + ): policy_args["region"]["merge_overlapping"] = True inferred_box_length = 160 # default if not set by user or inferrable @@ -369,35 +388,55 @@ def from_paths_and_policy_args(cls, model_paths: List[str], policy_args: dict): # Load the information for these models loaded_models = dict() - for model_path in model_paths: - training_job = job.TrainingJob.load_json(model_path) - inference_model = model.InferenceModel.from_training_job(training_job) - policy_key = model_type_policy_key_map[training_job.model.output_type] - loaded_models[policy_key] = dict( - job=training_job, inference_model=inference_model - ) + if model_paths: + for model_path in model_paths: + training_job = job.TrainingJob.load_json(model_path) + inference_model = model.InferenceModel.from_training_job(training_job) + policy_key = model_type_policy_key_map[training_job.model.output_type] + + loaded_models[policy_key] = dict( + job=training_job, inference_model=inference_model + ) - # Add policy classes which depend on models - for policy_key, policy_model in loaded_models.items(): - training_job = policy_model["job"] - inference_model = policy_model["inference_model"] + # Add policy classes which depend on models + for policy_key, policy_model in loaded_models.items(): + training_job = policy_model["job"] + inference_model = policy_model["inference_model"] - if policy_key == "confmap" and "paf" not in loaded_models.keys(): - # Use topdown class when we have confmaps and not pafs - policy_class = POLICY_CLASSES["topdown"] - else: - policy_class = POLICY_CLASSES[policy_key] + if policy_key == "confmap" and "paf" not in loaded_models.keys(): + # Use topdown class when we have confmaps and not pafs + policy_class = POLICY_CLASSES["topdown"] + else: + policy_class = POLICY_CLASSES[policy_key] - policy_object = policy_class( 
- inference_model=inference_model, **policy_args[policy_key] - ) + policy_object = policy_class( + inference_model=inference_model, **policy_args[policy_key] + ) - policies[policy_key] = policy_object + policies[policy_key] = policy_object - if training_job.trainer.bounding_box_size is not None: - if training_job.trainer.bounding_box_size > 0: - inferred_box_length = training_job.trainer.bounding_box_size + if training_job.trainer.bounding_box_size is not None: + if training_job.trainer.bounding_box_size > 0: + inferred_box_length = training_job.trainer.bounding_box_size + + # No models specified so see if we're using previous predictions + else: + try: + previous_labels = Labels.load_file( + args.data_path, video_callback=[os.path.dirname(args.data_path)], + ) + from .predicted import PredictedInstancePredictor + + policies["previous_predictions"] = PredictedInstancePredictor( + labels=previous_labels, + ) + print(f"Using previous predictions from {args.data_path}") + args.data_path = previous_labels.videos[0].filename + print(f"Setting video to {args.data_path}") + except Exception: + # We weren't able to read file as Labels object + pass if "topdown" in policies: policy_args["region"]["merge_overlapping"] = False @@ -492,22 +531,30 @@ def predict_subprocess( def check_valid_policies(cls, policies: dict) -> bool: has_topdown = "topdown" in policies - + has_previous = "previous_predictions" in policies + has_tracker = "tracking" in policies non_topdowns = [key for key in policies.keys() if key in ("confmap", "paf")] - if has_topdown and non_topdowns: - raise ValueError( - f"Cannot combine topdown model with non-topdown model" - f" {non_topdowns}." - ) + if has_previous: + if not has_tracker: + raise ValueError( + f"No tracker specified for running on previous predictions" + ) - if non_topdowns and "confmap" not in non_topdowns: - raise ValueError("Must have CONFIDENCE_MAP model.") + else: + if has_topdown and non_topdowns: + raise ValueError( + f"Cannot combine topdown model with non-topdown model" + f" {non_topdowns}." + ) - if not has_topdown and not non_topdowns: - raise ValueError( - f"Must have either TOPDOWN or CONFIDENCE_MAP/PART_AFFINITY_FIELD models." - ) + if non_topdowns and "confmap" not in non_topdowns: + raise ValueError("Must have CONFIDENCE_MAP model.") + + if not has_topdown and not non_topdowns: + raise ValueError( + f"Must have either TOPDOWN or CONFIDENCE_MAP/PART_AFFINITY_FIELD models." + ) return True diff --git a/sleap/nn/predicted.py b/sleap/nn/predicted.py new file mode 100644 index 000000000..55a113cbe --- /dev/null +++ b/sleap/nn/predicted.py @@ -0,0 +1,25 @@ +import attr + + +@attr.s(auto_attribs=True) +class PredictedInstancePredictor: + """ + Returns chunk of previously generated predictions in format of Predictor. + """ + + labels: "Labels" + video_idx: int = 0 + + def get_chunk(self, frame_inds): + video = self.labels.videos[self.video_idx] + + # Return dict keyed to sample index (i.e., offset in frame_inds), value + # is the list of instances for that frame. 
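+        # e.g., frame_inds=[100, 101] gives
+        # {0: [instances at frame 100], 1: [instances at frame 101]}.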
+ return { + i: [ + inst + for lf in self.labels.find(video=video, frame_idx=int(frame_idx)) + for inst in lf.instances + ] + for i, frame_idx in enumerate(frame_inds) + } diff --git a/sleap/nn/utils.py b/sleap/nn/utils.py index 39831f991..777a9bf07 100644 --- a/sleap/nn/utils.py +++ b/sleap/nn/utils.py @@ -309,6 +309,7 @@ class VideoLoader: dataset: str = None input_format: str = None grayscale: bool = False + dummy: bool = False chunk_size: int = 32 prefetch_chunks: int = 1 frame_inds: Optional[List[int]] = None @@ -367,8 +368,12 @@ def _load_video(self, filename) -> "Video": ) def load_frames(self, frame_inds): - local_vid = self._load_video(self.video.filename) - imgs = local_vid[np.array(frame_inds).astype("int64")] + if self.dummy: + dummy_shape = (len(frame_inds), *self._shape[1:]) + imgs = np.zeros(dummy_shape, dtype="int8") + else: + local_vid = self._load_video(self.video.filename) + imgs = local_vid[np.array(frame_inds).astype("int64")] return imgs def tf_load_frames(self, frame_inds): diff --git a/tests/data/hdf5_format_v1/centered_pair_predictions.h5 b/tests/data/hdf5_format_v1/centered_pair_predictions.h5 new file mode 100644 index 000000000..bf92b4592 Binary files /dev/null and b/tests/data/hdf5_format_v1/centered_pair_predictions.h5 differ diff --git a/tests/io/test_formats.py b/tests/io/test_formats.py new file mode 100644 index 000000000..72e04e391 --- /dev/null +++ b/tests/io/test_formats.py @@ -0,0 +1,108 @@ +from sleap.io.format import dispatch, adaptor, text, genericjson +import pytest +import os + + +def test_text_adaptor(tmpdir): + disp = dispatch.Dispatch() + disp.register(text.TextAdaptor()) + + filename = os.path.join(tmpdir, "textfile.txt") + some_text = "some text to save in a file" + + disp.write(filename, some_text) + + read_text = disp.read(filename) + + assert some_text == read_text + + +def test_json_adaptor(tmpdir): + disp = dispatch.Dispatch() + disp.register(genericjson.GenericJsonAdaptor()) + + filename = os.path.join(tmpdir, "jsonfile.json") + d = dict(foo=123, bar="zip") + + disp.write(filename, d) + + read_dict = disp.read(filename) + + assert d == read_dict + + assert disp.open(filename).is_json + + +def test_invalid_json(tmpdir): + # Write an "invalid" json file + filename = os.path.join(tmpdir, "textfile.json") + some_text = "some text to save in a file" + with open(filename, "w") as f: + f.write(some_text) + + disp = dispatch.Dispatch() + disp.register(genericjson.GenericJsonAdaptor()) + + assert not disp.open(filename).is_json + + with pytest.raises(TypeError): + disp.read(filename) + + +def test_no_matching_adaptor(): + disp = dispatch.Dispatch() + + with pytest.raises(TypeError): + disp.write("foo.txt", "foo") + + err = disp.write_safely("foo.txt", "foo") + + assert err is not None + + +def test_failed_read(): + disp = dispatch.Dispatch() + disp.register(text.TextAdaptor()) + + # Attempt to read hdf5 using text adaptor + hdf5_filename = "tests/data/hdf5_format_v1/training.scale=0.50,sigma=10.h5" + x, err = disp.read_safely(hdf5_filename) + + # There should be an error + assert err is not None + + +def test_missing_file(): + disp = dispatch.Dispatch() + disp.register(text.TextAdaptor()) + + with pytest.raises(FileNotFoundError): + disp.read("missing_file.txt") + + +def test_hdf5_v1(tmpdir): + filename = "tests/data/hdf5_format_v1/centered_pair_predictions.h5" + disp = dispatch.Dispatch.make_dispatcher(adaptor.SleapObjectType.labels) + + # Make sure reading works + x = disp.read(filename) + assert len(x.labeled_frames) == 1100 + + # Make sure 
writing works + filename = os.path.join(tmpdir, "test.h5") + disp.write(filename, x) + + # Make sure we can read the file we just wrote + y = disp.read(filename) + assert len(y.labeled_frames) == 1100 + + +def test_json_v1(tmpdir, centered_pair_labels): + filename = os.path.join(tmpdir, "test.json") + disp = dispatch.Dispatch.make_dispatcher(adaptor.SleapObjectType.labels) + + disp.write(filename, centered_pair_labels) + + # Make sure we can read the file we just wrote + y = disp.read(filename) + assert len(y.labeled_frames) == len(centered_pair_labels.labeled_frames) diff --git a/tests/nn/test_utils.py b/tests/nn/test_utils.py index 8701a77a5..488ae1da7 100644 --- a/tests/nn/test_utils.py +++ b/tests/nn/test_utils.py @@ -1,4 +1,5 @@ from sleap.nn.utils import VideoLoader +import numpy as np def test_grayscale_video(): @@ -7,3 +8,11 @@ def test_grayscale_video(): vid = VideoLoader(filename="tests/data/videos/small_robot.mp4", grayscale=True) assert vid.shape[-1] == 1 + + +def test_dummy_video(): + vid = VideoLoader(filename="tests/data/videos/small_robot.mp4", dummy=True) + + x = vid.load_frames([1, 3, 5]) + assert x.shape == (3, 320, 560, 3) + assert np.all(x == 0)
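To see how the pieces introduced in this diff fit together, here is a small usage sketch (hypothetical file paths; only APIs added above are used):

    from sleap.io.format import dispatch, adaptor

    # Dispatcher preloaded with the Labels adaptors that make_dispatcher
    # registers (HDF5 v1, Labels JSON, DeepLabCut CSV).
    disp = dispatch.Dispatch.make_dispatcher(adaptor.SleapObjectType.labels)

    # read() opens the file once as a FileHandle, then asks each adaptor
    # can_read_file() in turn and uses the first one that accepts it.
    labels, err = disp.read_safely("my_dataset.json")  # hypothetical path

    if err is None:
        # write() picks the first adaptor whose can_write_filename() matches.
        err = disp.write_safely("my_dataset.h5", labels)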