Skip to content

Commit

Permalink
support loading slp files with non-compound types and str in metadata (
Browse files Browse the repository at this point in the history
…#1566)

Co-authored-by: Liezl Maree <38435167+roomrys@users.noreply.github.com>
  • Loading branch information
lambdaloop and roomrys committed Jan 5, 2024
1 parent 16241e0 commit 14b5b78
Showing 1 changed file with 47 additions and 4 deletions.
51 changes: 47 additions & 4 deletions sleap/io/format/hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,10 @@ def read_headers(

# Extract the Labels JSON metadata and create Labels object with just this
# metadata.
dicts = json_loads(f.require_group("metadata").attrs["json"].tobytes().decode())
json = f.require_group("metadata").attrs["json"]
if not isinstance(json, str):
json = json.tobytes().decode()
dicts = json_loads(json)

# These items are stored in separate lists because the metadata group got to be
# too big.
Expand Down Expand Up @@ -151,6 +154,45 @@ def read(
points_dset[:]["x"] -= 0.5
points_dset[:]["y"] -= 0.5

def cast_as_compound(arr, dtype):
out = np.empty(shape=(len(arr),), dtype=dtype)
if out.size == 0:
return out
for i, (name, _) in enumerate(dtype):
out[name] = arr[:, i]
return out

# cast points, instances, and frames into complex dtype if not already
dtype_points = [("x", "<f8"), ("y", "<f8"), ("visible", "?"), ("complete", "?")]
if points_dset.dtype.kind != "V":
points_dset = cast_as_compound(points_dset, dtype_points)
if pred_points_dset.dtype.kind != "V":
pred_points_dset = cast_as_compound(pred_points_dset, dtype_points)

dtype_instances = [
("instance_id", "<i8"),
("instance_type", "u1"),
("frame_id", "<u8"),
("skeleton", "<u4"),
("track", "<i4"),
("from_predicted", "<i8"),
("score", "<f4"),
("point_id_start", "<u8"),
("point_id_end", "<u8"),
]
if instances_dset.dtype.kind != "V":
instances_dset = cast_as_compound(instances_dset, dtype_instances)

dtype_frames = [
("frame_id", "<u8"),
("video", "<u4"),
("frame_idx", "<u8"),
("instance_id_start", "<u8"),
("instance_id_end", "<u8"),
]
if frames_dset.dtype.kind != "V":
frames_dset = cast_as_compound(frames_dset, dtype_frames)

# Rather than instantiate a bunch of Point\PredictedPoint objects, we will use
# inplace numpy recarrays. This will save a lot of time and memory when reading
# things in.
Expand Down Expand Up @@ -283,9 +325,10 @@ def write(
if append and "json" in meta_group.attrs:

# Otherwise, we need to read the JSON and append to the lists
old_labels = labels_json.LabelsJsonAdaptor.from_json_data(
meta_group.attrs["json"].tobytes().decode()
)
json = meta_group.attrs["json"]
if not isinstance(json, str):
json = json.tobytes().decode()
old_labels = labels_json.LabelsJsonAdaptor.from_json_data(json)

# A function to join to list but only include new non-dupe entries
# from the right hand list.
Expand Down

0 comments on commit 14b5b78

Please sign in to comment.