Skip to content

Commit

Permalink
implement __hash__ for Sequence class
Browse files Browse the repository at this point in the history
  • Loading branch information
NickleDave committed May 7, 2019
1 parent 5bfeaff commit a9bbbc9
Showing 1 changed file with 60 additions and 26 deletions.
86 changes: 60 additions & 26 deletions src/crowsetta/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,37 +80,74 @@ def __init__(self,
raise ValueError(f"File for segments, '{file_from_segments}', "
f"does not match file, '{file}'.")

super().__setattr__('segments', segments)
super().__setattr__('onsets_s', onsets_s)
super().__setattr__('offsets_s', offsets_s)
super().__setattr__('onsets_Hz', onsets_Hz)
super().__setattr__('offsets_Hz', offsets_Hz)
super().__setattr__('labels', labels)
super().__setattr__('file', file)
super().__setattr__('_segments', segments)
super().__setattr__('_onsets_s', onsets_s)
super().__setattr__('_offsets_s', offsets_s)
super().__setattr__('_onsets_Hz', onsets_Hz)
super().__setattr__('_offsets_Hz', offsets_Hz)
super().__setattr__('_labels', labels)
super().__setattr__('_file', file)

@property
def segments(self):
return self._segments

@property
def onsets_s(self):
return self._onsets_s

@property
def offsets_s(self):
return self._offsets_s

@property
def onsets_Hz(self):
return self._onsets_Hz

@property
def offsets_Hz(self):
return self._offsets_Hz

@property
def labels(self):
return self._labels

@property
def file(self):
return self._file

def __hash__(self):
return hash(
(self._segments,
self._onsets_s,
self._offsets_s,
self._onsets_Hz,
self._offsets_Hz,
self._labels,
self._file)
)

def __repr__(self):
return f"<Sequence with {len(self.segments)} segments>"

def __eq__(self, other):
if self.__class__ == other.__class__:
eq = []
for attr in ['segments', 'labels', 'file',
'onsets_s', 'offsets_s', 'onsets_Hz', 'offsets_Hz']:
self_attr = getattr(self, attr)
other_attr = getattr(other, attr)
if type(self_attr) == np.ndarray:
eq.append(np.array_equal(self_attr, other_attr))
else:
eq.append(self_attr == other_attr)

if all(eq):
return True
if not isinstance(other, Sequence):
return False

eq = []
for attr in ['_segments', '_labels', '_file',
'_onsets_s', '_offsets_s', '_onsets_Hz', '_offsets_Hz']:
self_attr = getattr(self, attr)
other_attr = getattr(other, attr)
if type(self_attr) == np.ndarray:
eq.append(np.array_equal(self_attr, other_attr))
else:
return False
eq.append(self_attr == other_attr)

if all(eq):
return True
else:
raise TypeError("can only test for equality between two Sequences, not "
f"between a Sequence and {type(other)}")
return False

def __ne__(self, other):
if self.__class__ == other.__class__:
Expand All @@ -134,9 +171,6 @@ def __gt__(self, other):
def __ge__(self, other):
raise NotImplementedError

def __hash__(self):
raise NotImplementedError

@staticmethod
def _convert_labels(labels):
if type(labels) == str:
Expand Down

0 comments on commit a9bbbc9

Please sign in to comment.