-
Notifications
You must be signed in to change notification settings - Fork 97
/
instance.py
1808 lines (1479 loc) · 62.4 KB
/
instance.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
Data structures for all labeled data contained with a SLEAP project.
The relationships between objects in this module:
* A `LabeledFrame` can contain zero or more `Instance`s
(and `PredictedInstance` objects).
* `Instance` objects (and `PredictedInstance` objects) have `PointArray`
(or `PredictedPointArray`).
* `Instance` (`PredictedInstance`) can be associated with a `Track`
* A `PointArray` (or `PredictedPointArray`) contains zero or more
`Point` objects (or `PredictedPoint` objectss), ideally as many as
there are in the associated :class:`Skeleton` although these can get
out of sync if the skeleton is manipulated.
"""
import math
import numpy as np
import cattr
from copy import copy
from typing import Dict, List, Optional, Union, Tuple, ForwardRef
from numpy.lib.recfunctions import structured_to_unstructured
import sleap
from sleap.skeleton import Skeleton, Node
from sleap.io.video import Video
import attr
class Point(np.record):
    """A labelled point and any metadata associated with it.

    Args:
        x: The horizontal pixel location of point within image frame.
        y: The vertical pixel location of point within image frame.
        visible: Whether point is visible in the labelled image or not.
        complete: Has the point been verified by the user labeler.
    """

    # Structured dtype describing the fields stored for every point. The field
    # names double as the attribute names exposed on each record.
    dtype = np.dtype([("x", "f8"), ("y", "f8"), ("visible", "?"), ("complete", "?")])

    def __new__(
        cls,
        x: float = math.nan,
        y: float = math.nan,
        visible: bool = True,
        complete: bool = False,
    ) -> "Point":
        """Create a single point record holding the given field values."""
        # HACK: There is no obvious way to instantiate a bare np.record, so we
        # allocate a one-element PointArray, fill it and index the record back
        # out. This is what makes Point(x=2, y=3) work as expected.
        val = PointArray(1)
        val[0] = (x, y, visible, complete)
        return val[0]

    def __str__(self) -> str:
        """Return the point formatted as ``(x, y)``."""
        return f"({self.x}, {self.y})"

    def isnan(self) -> bool:
        """Check whether either of the coordinates is a NaN value.

        Returns:
            True if x or y is NaN, False otherwise.
        """
        return math.isnan(self.x) or math.isnan(self.y)

    def numpy(self) -> np.ndarray:
        """Return the point coordinates as a numpy array ``[x, y]``."""
        # The original signature lacked ``self``, which made the method
        # impossible to call on an instance.
        return np.array([self.x, self.y])
# This turns Point into an attrs class. Defines comparators for
# us and generally makes it behave better. Crazy that this works!
# One attr.ib is declared per dtype field; init=False asks attrs not to
# generate an __init__, so construction keeps going through Point.__new__.
Point = attr.s(these={name: attr.ib() for name in Point.dtype.names}, init=False)(Point)
class PredictedPoint(Point):
    """A predicted point is an output of the inference procedure.

    It has all the properties of a labeled point, plus a score.

    Args:
        x: The horizontal pixel location of point within image frame.
        y: The vertical pixel location of point within image frame.
        visible: Whether point is visible in the labelled image or not.
        complete: Has the point been verified by the user labeler.
        score: The point-level prediction score.
    """

    # Same fields as Point.dtype plus a trailing "score" field.
    dtype = np.dtype(
        [("x", "f8"), ("y", "f8"), ("visible", "?"), ("complete", "?"), ("score", "f8")]
    )

    def __new__(
        cls,
        x: float = math.nan,
        y: float = math.nan,
        visible: bool = True,
        complete: bool = False,
        score: float = 0.0,
    ) -> "PredictedPoint":
        """Create a single predicted point record holding the given values."""
        # HACK: There is no obvious way to instantiate a bare np.record, so we
        # allocate a one-element PredictedPointArray, fill it and index the
        # record back out. This makes PredictedPoint(x=2, y=3) work as expected.
        val = PredictedPointArray(1)
        val[0] = (x, y, visible, complete, score)
        return val[0]

    @classmethod
    def from_point(cls, point: Point, score: float = 0.0) -> "PredictedPoint":
        """Create a PredictedPoint from a Point.

        Args:
            point: The point to copy all data from.
            score: The score for this predicted point.

        Returns:
            A scored point based on the point passed in.
        """
        # Copy only the fields shared with Point; the classes define no
        # ``asdict`` attribute, so read the values by dtype field name. Any
        # score already carried by ``point`` is replaced by ``score``.
        fields = {name: point[name] for name in Point.dtype.names}
        fields["score"] = score
        return cls(**fields)
# This turns PredictedPoint into an attrs class. Defines comparators for
# us and generally makes it behave better. Crazy that this works!
# One attr.ib is declared per dtype field; init=False asks attrs not to
# generate an __init__, so construction keeps going through __new__.
PredictedPoint = attr.s(
these={name: attr.ib() for name in PredictedPoint.dtype.names}, init=False
)(PredictedPoint)
class PointArray(np.recarray):
    """A numpy recarray subclass whose records are :class:`Point` objects."""

    # Record class stored in this array; subclasses override it.
    _record_type = Point

    def __new__(
        subtype,
        shape,
        buf=None,
        offset=0,
        strides=None,
        formats=None,
        names=None,
        titles=None,
        byteorder=None,
        aligned=False,
        order="C",
    ) -> "PointArray":
        """Allocate an array of ``shape`` whose element type is the record type."""
        record_cls = subtype._record_type
        record_dtype = record_cls.dtype

        # Prefer the dtype declared on the record class; only fall back to the
        # recarray-style format description when none is declared.
        descr = (
            np.dtype(record_dtype)
            if record_dtype is not None
            else np.format_parser(formats, names, titles, aligned, byteorder)._descr
        )
        element_type = (record_cls, descr)

        if buf is not None:
            # Wrap the caller-supplied buffer rather than allocating memory.
            return np.ndarray.__new__(
                subtype,
                shape,
                element_type,
                buffer=buf,
                offset=offset,
                strides=strides,
                order=order,
            )

        return np.ndarray.__new__(subtype, shape, element_type, order=order)

    def __array_finalize__(self, obj):
        """Override :method:`np.recarray.__array_finalize__()` with a no-op.

        The recarray implementation converts the dtype of any np.void subclass
        to np.record, which we do not want.
        """
        pass

    @classmethod
    def make_default(cls, size: int) -> "PointArray":
        """Construct a point array where all points are set to default.

        Args:
            size: The number of points to allocate.

        Returns:
            An array of ``size`` records, each assigned the value of a
            default-constructed record of this array's record type.
        """
        defaults = cls(size)
        defaults[:] = cls._record_type()
        return defaults

    def __getitem__(self, indx: int) -> "Point":
        """Get point by its index in the array."""
        # Deliberately skip np.recarray.__getitem__ and use the ndarray one.
        item = super(np.recarray, self).__getitem__(indx)

        # A single element comes back as a record; hand it over unchanged.
        if not isinstance(item, np.ndarray):
            return item

        # Structured sub-arrays keep this array type; plain field arrays are
        # returned as ordinary ndarrays.
        if item.dtype.fields:
            return item.view(type(self))
        return item.view(type=np.ndarray)

    @classmethod
    def from_array(cls, a: "PointArray") -> "PointArray":
        """Convert a :class:`PointArray` (or child) to a new array of this type.

        The new array starts from the default values of ``cls`` and then takes
        over every field listed in ``Point.dtype`` from ``a``.

        Args:
            a: The array to convert.

        Returns:
            An array of type ``cls`` with the same points as ``a``.
        """
        converted = cls.make_default(len(a))
        for name in Point.dtype.names:
            converted[name] = a[name]
        return converted
class PredictedPointArray(PointArray):
    """The analogue of :class:`PointArray` for predicted points."""

    # Records stored in this array are PredictedPoint objects.
    _record_type = PredictedPoint

    @classmethod
    def to_array(cls, a: "PredictedPointArray") -> "PointArray":
        """Convert a PredictedPointArray to a normal PointArray.

        Only the fields listed in ``Point.dtype`` are copied over.

        Args:
            a: The array to convert.

        Returns:
            The converted array.
        """
        plain = PointArray.make_default(len(a))
        for name in Point.dtype.names:
            plain[name] = a[name]
        return plain
@attr.s(slots=True, eq=False, order=False)
class Track:
    """A track groups animal/object instances across multiple video frames.

    This allows tracking of unique entities in the video over time and space.

    Args:
        spawned_on: The video frame that this track was spawned on.
        name: A name given to this track for identifying purposes.
    """

    spawned_on: int = attr.ib(default=0, converter=int)
    name: str = attr.ib(default="", converter=str)

    def matches(self, other: "Track") -> bool:
        """Check if two tracks match by value.

        Args:
            other: The other track to check

        Returns:
            True if they match, False otherwise.
        """
        own_fields = attr.asdict(self)
        other_fields = attr.asdict(other)
        return own_fields == other_fields
# NOTE:
# Instance is declared as a slotted attrs class (``slots=True``). Every piece of
# per-instance state, including the private ``_points`` array and the ``_nodes``
# cache, is therefore declared as an ``attr.ib`` field below rather than being
# created after init.
@attr.s(eq=False, order=False, slots=True, repr=False, str=False)
class Instance:
    """This class represents a labeled instance.

    Args:
        skeleton: The skeleton that this instance is associated with.
        points: A dictionary where keys are skeleton node names and
            values are Point objects. Alternatively, a point array whose
            length and order matches skeleton.nodes.
        track: An optional multi-frame object track associated with
            this instance. This allows individual animals/objects to be
            tracked across frames.
        from_predicted: The predicted instance (if any) that this was
            copied from.
        frame: A back reference to the :class:`LabeledFrame` that this
            :class:`Instance` belongs to. This field is set when
            instances are added to :class:`LabeledFrame` objects.
    """

    skeleton: Skeleton = attr.ib()
    track: Track = attr.ib(default=None)
    from_predicted: Optional["PredictedInstance"] = attr.ib(default=None)
    _points: PointArray = attr.ib(default=None)
    _nodes: List = attr.ib(default=None)
    frame: Union["LabeledFrame", None] = attr.ib(default=None)

    # The underlying Point array type that this instances point array should be.
    _point_array_type = PointArray

    @from_predicted.validator
    def _validate_from_predicted_(
        self, attribute, from_predicted: Optional["PredictedInstance"]
    ):
        """Validation method called by attrs.

        Checks that from_predicted is None or :class:`PredictedInstance`

        Args:
            attribute: Attribute being validated; not used.
            from_predicted: Value being validated.

        Raises:
            TypeError: If from_predicted is anything other than None
                or a `PredictedInstance`.
        """
        if from_predicted is not None and not isinstance(
            from_predicted, PredictedInstance
        ):
            raise TypeError(
                f"Instance.from_predicted type must be PredictedInstance (not "
                f"{type(from_predicted)})"
            )

    @_points.validator
    def _validate_all_points(self, attribute, points: Union[dict, PointArray]):
        """Validation method called by attrs.

        Checks that the points given for the instance are consistent with the
        skeleton.

        Args:
            attribute: Attribute being validated; not used.
            points: Either dict of points or PointArray
                If dict, keys should be node names.

        Raises:
            KeyError: If a point is keyed by a node name that doesn't exist
                in the skeleton.
            ValueError: If a PointArray does not have one row per skeleton node.

        Returns:
            None
        """
        if isinstance(points, dict):
            # Only dictionaries keyed purely by node names are checked here.
            is_string_dict = set(map(type, points)) == {str}
            if is_string_dict:
                for node_name in points:
                    if not self.skeleton.has_node(node_name):
                        raise KeyError(
                            f"There is no node named {node_name} in {self.skeleton}"
                        )
        elif isinstance(points, PointArray):
            if len(points) != len(self.skeleton.nodes):
                raise ValueError(
                    "PointArray does not have the same number of rows as skeleton "
                    "nodes."
                )

    def __attrs_post_init__(self):
        """Method called by attrs after __init__().

        Initializes points if none were specified when creating object,
        caches list of nodes so what we can still find points in array
        if the `Skeleton` changes.

        Args:
            None

        Raises:
            ValueError: If object has no `Skeleton`.
        """
        if self.skeleton is None:
            raise ValueError("No skeleton set for Instance")

        if self._points is None or len(self._points) == 0:
            # No points were passed: start from an all-default point array that
            # has one row per skeleton node.
            self._points = self._point_array_type.make_default(len(self.skeleton.nodes))
        elif isinstance(self._points, dict):
            # Replace the points dict with the equivalent point array.
            parray = self._point_array_type.make_default(len(self.skeleton.nodes))
            Instance._points_dict_to_array(self._points, parray, self.skeleton)
            self._points = parray

        # Now that we've validated the points, cache the list of nodes
        # in the skeleton since the PointArray indexing will be linked
        # to this list even if nodes are removed from the skeleton.
        self._nodes = self.skeleton.nodes

    @staticmethod
    def _points_dict_to_array(
        points: Dict[Union[str, Node], Point], parray: PointArray, skeleton: Skeleton
    ):
        """Set values in given :class:`PointsArray` from dictionary.

        Args:
            points: The dictionary of points. Keys can be either node
                names or :class:`Node`s, values are :class:`Point`s.
            parray: The :class:`PointsArray` which is being updated.
            skeleton: The :class:`Skeleton` which contains the nodes
                referenced in the dictionary of points.

        Raises:
            ValueError: If dictionary keys are not either all strings
                or all :class:`Node`s.
        """
        key_types = set(map(type, points))
        is_string_dict = key_types == {str}
        is_node_dict = key_types == {Node}

        if not is_string_dict and not is_node_dict:
            raise ValueError(
                "points dictionary must be keyed by either strings "
                + "(node names) or Nodes."
            )

        # If the user fed in a dict whose keys are strings, these are node names,
        # convert to nodes so we don't break references to skeleton nodes
        # if the node name is relabeled.
        if is_string_dict:
            points = {skeleton.find_node(name): point for name, point in points.items()}

        # Exact type check on purpose: PredictedPointArray subclasses PointArray
        # and must keep its PredictedPoint values.
        fills_plain_array = type(parray) is PointArray

        for node, point in points.items():
            # Convert PredictedPoint to Point if Instance
            if fills_plain_array and type(point) is PredictedPoint:
                point = Point(
                    x=point.x, y=point.y, visible=point.visible, complete=point.complete
                )
            try:
                parray[skeleton.node_to_index(node)] = point
            except Exception:
                # Best effort: a point that cannot be placed (for example its
                # node is not in the skeleton) is skipped. Unlike a bare
                # ``except`` this no longer swallows KeyboardInterrupt/SystemExit.
                continue

    def _node_to_index(self, node: Union[str, Node]) -> int:
        """Helper method to get the index of a node from its name.

        Args:
            node: Node name or :class:`Node` object.

        Returns:
            The index of the node on skeleton graph.
        """
        return self.skeleton.node_to_index(node)

    def __getitem__(
        self,
        node: Union[List[Union[str, Node, int]], Union[str, Node, int], np.ndarray],
    ) -> Union[List[Point], Point, np.ndarray]:
        """Get the Points associated with particular skeleton node(s).

        Args:
            node: A single node or list of nodes within the skeleton
                associated with this instance.

        Raises:
            KeyError: If node cannot be found in skeleton.

        Returns:
            Either a single point (if a single node given), or
            a list of points (if a list of nodes given) corresponding
            to each node. If the nodes are given as a numpy array, an array
            of `(x, y)` rows is returned instead of a list.
        """
        self._fix_array()

        # For a collection of nodes, look each one up recursively.
        if isinstance(node, (list, tuple, np.ndarray)):
            pts = [self[n] for n in node]
            if isinstance(node, np.ndarray):
                return np.array([[pt.x, pt.y] for pt in pts])
            return pts

        if isinstance(node, (Node, str)):
            try:
                node = self._node_to_index(node)
            except ValueError as err:
                raise KeyError(
                    f"The underlying skeleton ({self.skeleton}) has no node '{node}'"
                ) from err

        return self._points[node]

    def __contains__(self, node: Union[str, Node, int]) -> bool:
        """Whether this instance has a point with the specified node.

        Args:
            node: Node name or :class:`Node` object.

        Returns:
            bool: True if the point with the node name specified has a
            point in this instance.
        """
        # Bring the point array in line with the skeleton first, like the other
        # accessors do, so the index computed below refers to the right row.
        self._fix_array()

        if isinstance(node, Node):
            node = node.name
        if isinstance(node, str):
            if node not in self.skeleton:
                return False
            node = self._node_to_index(node)

        # If the points are nan, then they haven't been allocated.
        return not self._points[node].isnan()

    def __setitem__(
        self,
        node: Union[List[Union[str, Node, int]], Union[str, Node, int], np.ndarray],
        value: Union[List[Point], Point, np.ndarray],
    ):
        """Set the point(s) for given node(s).

        Args:
            node: Either node (by name or `Node`) or list of nodes.
            value: Either `Point` or list of `Point`s. A single value may also
                be any length-2 sequence of `(x, y)` coordinates.

        Raises:
            IndexError: If lengths of lists don't match, or if exactly
                one of the inputs is a list.
            KeyError: If skeleton does not have (one of) the node(s).
            ValueError: If a single value is neither a `Point` nor a length-2
                sequence.
        """
        self._fix_array()

        # Collections of nodes (tuples included, as in __getitem__) are set
        # element by element and need a value collection of the same length.
        if isinstance(node, (list, tuple, np.ndarray)):
            if (
                not isinstance(value, (list, tuple, np.ndarray))
                or len(value) != len(node)
            ):
                raise IndexError(
                    "Node list for indexing must be same length and value list."
                )
            for n, v in zip(node, value):
                self[n] = v
            return

        if isinstance(node, (Node, str)):
            try:
                node_idx = self._node_to_index(node)
            except ValueError as err:
                raise KeyError(
                    f"The skeleton ({self.skeleton}) has no node '{node}'."
                ) from err
        else:
            # Anything else is used directly as an index into the point array.
            node_idx = node

        if not isinstance(value, Point):
            if hasattr(value, "__len__") and len(value) == 2:
                value = Point(x=value[0], y=value[1])
            else:
                raise ValueError("Instance point values must be (x, y) coordinates.")

        self._points[node_idx] = value

    def __delitem__(self, node: Union[str, Node]):
        """Delete node key and points associated with that node.

        The point is not removed from the array; its coordinates are set to
        NaN, which marks it as missing.

        Args:
            node: Node name or :class:`Node` object.

        Raises:
            KeyError: If skeleton does not have the node.

        Returns:
            None
        """
        # Keep the point array aligned with the skeleton before indexing it.
        self._fix_array()

        try:
            node_idx = self._node_to_index(node)
        except ValueError as err:
            raise KeyError(
                f"The underlying skeleton ({self.skeleton}) has no node '{node}'"
            ) from err

        self._points[node_idx].x = math.nan
        self._points[node_idx].y = math.nan

    def __repr__(self) -> str:
        """Return string representation of this object."""
        pts = ", ".join(
            f"{node.name}: ({pt.x:.1f}, {pt.y:.1f})" for node, pt in self.nodes_points
        )
        return (
            "Instance("
            f"video={self.video}, "
            f"frame_idx={self.frame_idx}, "
            f"points=[{pts}], "
            f"track={self.track}"
            ")"
        )

    def matches(self, other: "Instance") -> bool:
        """Whether two instances match by value.

        Checks the types, points, skeleton, track, and frame index.

        Args:
            other: The other :class:`Instance`.

        Returns:
            True if match, False otherwise.
        """
        if type(self) is not type(other):
            return False

        if list(self.points) != list(other.points):
            return False

        if not self.skeleton.matches(other.skeleton):
            return False

        # Either both instances have a track or neither does...
        if bool(self.track) != bool(other.track):
            return False

        # ...and when both do, the tracks have to match by value.
        if self.track and not self.track.matches(other.track):
            return False

        # Make sure the frame indices match
        return self.frame_idx == other.frame_idx

    @property
    def nodes(self) -> Tuple[Node, ...]:
        """Return nodes that have been labelled for this instance."""
        self._fix_array()
        skeleton_nodes = self.skeleton.nodes
        return tuple(
            node
            for node, point in zip(self._nodes, self._points)
            if not point.isnan() and node in skeleton_nodes
        )

    @property
    def nodes_points(self) -> List[Tuple[Node, Point]]:
        """Return a list of (node, point) tuples for all labeled points."""
        # ``nodes`` and ``points`` both skip NaN points and keep array order,
        # so they pair up one to one.
        return list(zip(self.nodes, self.points))

    @property
    def points(self) -> Tuple[Point, ...]:
        """Return a tuple of labelled (non-NaN) points, in point array order."""
        self._fix_array()
        return tuple(point for point in self._points if not point.isnan())

    def _fix_array(self):
        """Fix PointArray after nodes have been added or removed.

        This updates the PointArray as required by comparing the cached
        list of nodes to the nodes in the `Skeleton` object (which may
        have changed).
        """
        current_nodes = self.skeleton.nodes

        # Nothing to do while the cached nodes still equal the skeleton's.
        if self._nodes == current_nodes:
            return

        # Create new PointArray (or PredictedPointArray)
        new_array = type(self._points).make_default(len(current_nodes))

        # Carry over the points of every cached node that still exists, placing
        # each at the node's current position in the skeleton.
        for old_idx, node in enumerate(self._nodes):
            if node in current_nodes:
                new_array[current_nodes.index(node)] = self._points[old_idx]

        # Update points and nodes for this instance
        self._points = new_array
        self._nodes = current_nodes

    def get_points_array(
        self, copy: bool = True, invisible_as_nan: bool = False, full: bool = False
    ) -> Union[np.ndarray, np.recarray]:
        """Return the instance's points in array form.

        Args:
            copy: If True, the return a copy of the points array as an ndarray.
                If False, return a view of the underlying recarray.
            invisible_as_nan: Should invisible points be marked as NaN.
                If copy is False, then invisible_as_nan is ignored since we
                don't want to set invisible points to NaNs in original data.
            full: If True, return all data for points. Otherwise, return just
                the x and y coordinates.

        Returns:
            Either a recarray (if copy is False) or an ndarray (if copy True).

            The order of the rows corresponds to the ordering of the skeleton
            nodes. Any skeleton node not defined will have NaNs present.

            Columns in recarray are accessed by name, e.g., ["x"], ["y"].

            Columns in ndarray are accessed by number. The order matches
            the order in `Point.dtype` or `PredictedPoint.dtype`.
        """
        self._fix_array()

        selected = self._points if full else self._points[["x", "y"]]
        if not copy:
            return selected

        # Ask for a copy explicitly so that the NaN masking below can never
        # write through to the instance's own data.
        parray = structured_to_unstructured(selected, copy=True)

        # Note that invisible_as_nan assumes copy is True.
        if invisible_as_nan:
            parray[~self._points.visible] = math.nan

        return parray

    def fill_missing(
        self, max_x: Optional[float] = None, max_y: Optional[float] = None
    ):
        """Add points for skeleton nodes that are missing in the instance.

        This is useful when modifying the skeleton so the nodes appears in the GUI.
        New points are placed at random positions inside the bounding box of the
        existing points and are marked as not visible.

        Args:
            max_x: If specified, make sure points are not added outside of valid range.
            max_y: If specified, make sure points are not added outside of valid range.
        """
        self._fix_array()

        y1, x1, y2, x2 = self.bounding_box
        y1, x1 = np.nanmax([y1, 0]), np.nanmax([x1, 0])
        if max_x is not None:
            x2 = np.nanmin([x2, max_x])
        if max_y is not None:
            y2 = np.nanmin([y2, max_y])

        # Width is the x extent and height the y extent. The offset below is
        # ordered (x, y), so the extents must be in the same order.
        width, height = x2 - x1, y2 - y1

        for node in self.skeleton.nodes:
            # A node is missing exactly when its point is NaN.
            if not self[node].isnan():
                continue

            offset = np.array([width, height]) * np.random.rand(2)
            x, y = offset + np.array([x1, y1])
            y, x = max(y, 0), max(x, 0)
            if max_x is not None:
                x = min(x, max_x)
            if max_y is not None:
                y = min(y, max_y)
            self[node] = Point(x=x, y=y, visible=False)

    @property
    def points_array(self) -> np.ndarray:
        """Return array of x and y coordinates for visible points.

        Row in array corresponds to order of points in skeleton. Invisible points will
        be denoted by NaNs.

        Returns:
            A numpy array of of shape `(n_nodes, 2)` point coordinates.
        """
        return self.get_points_array(invisible_as_nan=True)

    def numpy(self) -> np.ndarray:
        """Return the instance node coordinates as a numpy array.

        Alias for `points_array`.

        Returns:
            Array of shape `(n_nodes, 2)` of dtype `float64` containing the coordinates
            of the instance's nodes. Missing/not visible nodes will be replaced with
            `NaN`.
        """
        return self.points_array

    def transform_points(self, transformation_matrix):
        """Apply a transformation matrix to the points in the instance, in place.

        Args:
            transformation_matrix: Transformation matrix as a numpy array. If it
                has three columns (e.g. an affine matrix of shape `(3, 3)` or
                `(2, 3)`), the first two columns are applied as the linear part
                and the third column is added as the translation. Otherwise the
                whole matrix is applied as a linear transformation.
        """
        points = self.get_points_array(copy=True, full=False, invisible_as_nan=False)

        if transformation_matrix.shape[1] == 3:
            linear = transformation_matrix[:, :2]
            translation = transformation_matrix[:, 2]
            transformed = points @ linear.T + translation
        else:
            transformed = points @ transformation_matrix.T

        # Only the first two output columns are coordinates.
        self._points["x"] = transformed[:, 0]
        self._points["y"] = transformed[:, 1]

    @property
    def centroid(self) -> np.ndarray:
        """Return instance centroid as an array of `(x, y)` coordinates

        Notes:
            This computes the centroid as the median of the visible points.
        """
        return np.nanmedian(self.points_array, axis=0)

    @property
    def bounding_box(self) -> np.ndarray:
        """Return bounding box containing all points in `[y1, x1, y2, x2]` format.

        All four values are NaN when the instance has no visible points.
        """
        points = self.points_array
        if np.isnan(points).all():
            return np.array([np.nan, np.nan, np.nan, np.nan])

        # Points are stored as (x, y); reverse each extreme to get (y, x).
        lower = np.nanmin(points, axis=0)[::-1]
        upper = np.nanmax(points, axis=0)[::-1]
        return np.concatenate([lower, upper])

    @property
    def midpoint(self) -> np.ndarray:
        """Return the center `(x, y)` of the bounding box of the instance points."""
        y1, x1, y2, x2 = self.bounding_box
        # The center is the mean of opposite corners. Halving the difference
        # would give half the box size instead.
        return np.array([(x1 + x2) / 2, (y1 + y2) / 2])

    @property
    def n_visible_points(self) -> int:
        """Return the number of visible points in this instance."""
        return sum(1 for point in self.points if point.visible)

    def __len__(self) -> int:
        """Return the number of visible points in this instance."""
        return self.n_visible_points

    @property
    def video(self) -> Optional[Video]:
        """Return the video of the labeled frame this instance is associated with."""
        if self.frame is None:
            return None
        return self.frame.video

    @property
    def frame_idx(self) -> Optional[int]:
        """Return the index of the labeled frame this instance is associated with."""
        if self.frame is None:
            return None
        return self.frame.frame_idx

    @classmethod
    def from_pointsarray(
        cls, points: np.ndarray, skeleton: Skeleton, track: Optional[Track] = None
    ) -> "Instance":
        """Create an instance from an array of points.

        Args:
            points: A numpy array of shape `(n_nodes, 2)` that contains the
                points in (x, y) coordinates of each node. Missing nodes
                should be represented as `NaN`.
            skeleton: A `sleap.Skeleton` instance with `n_nodes` nodes to associate with
                the instance.
            track: Optional `sleap.Track` object to associate with the instance.

        Returns:
            A new `Instance` object.
        """
        # Rows containing a NaN are treated as missing and left out.
        points_by_name = {
            node_name: Point(x=point[0], y=point[1])
            for point, node_name in zip(points, skeleton.node_names)
            if not np.isnan(point).any()
        }
        return cls(points=points_by_name, skeleton=skeleton, track=track)

    @classmethod
    def from_numpy(
        cls, points: np.ndarray, skeleton: Skeleton, track: Optional[Track] = None
    ) -> "Instance":
        """Create an instance from a numpy array.

        Args:
            points: A numpy array of shape `(n_nodes, 2)` that contains the
                points in (x, y) coordinates of each node. Missing nodes
                should be represented as `NaN`.
            skeleton: A `sleap.Skeleton` instance with `n_nodes` nodes to associate with
                the instance.
            track: Optional `sleap.Track` object to associate with the instance.

        Returns:
            A new `Instance` object.

        Notes:
            This is an alias for `Instance.from_pointsarray()`.
        """
        return cls.from_pointsarray(points, skeleton, track=track)

    def _merge_nodes_data(self, base_node: str, merge_node: str):
        """Copy point data from one node to another.

        The base point is only overwritten when it is missing (NaN) or not
        visible, and only when the merge point is not missing.

        Args:
            base_node: Name of node that will be merged into.
            merge_node: Name of node that will be removed after merge.

        Notes:
            This is used when merging skeleton nodes and should not be called directly.
        """
        base_pt = self[base_node]
        merge_pt = self[merge_node]

        if merge_pt.isnan():
            return

        if base_pt.isnan() or not base_pt.visible:
            base_pt.x = merge_pt.x
            base_pt.y = merge_pt.y
            base_pt.visible = merge_pt.visible
            base_pt.complete = merge_pt.complete
            # Predicted points additionally carry a score.
            if hasattr(base_pt, "score"):
                base_pt.score = merge_pt.score
@attr.s(eq=False, order=False, slots=True, repr=False, str=False)
class PredictedInstance(Instance):
"""
A predicted instance is an output of the inference procedure.
Args:
score: The instance-level grouping prediction score.