Skip to content

Commit

Permalink
MRG: #176 from vocalpy/fix-register-format
Browse files Browse the repository at this point in the history
Fix `crowsetta.formats.register_format`
  • Loading branch information
NickleDave committed May 15, 2022
2 parents 09d770f + 85b7f5a commit 6509fac
Show file tree
Hide file tree
Showing 10 changed files with 313 additions and 110 deletions.
177 changes: 115 additions & 62 deletions doc/notebooks/batlab2seq.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,116 @@
import pathlib
from typing import ClassVar

import attr
import numpy as np
from scipy.io import loadmat

from crowsetta.sequence import Sequence


def batlab2seq(mat_file):
    """Unpack BatLAB annotation into a list of Sequence objects.

    Example of a function that unpacks annotation from
    a complicated data structure and returns the necessary
    data as Sequence objects.

    Parameters
    ----------
    mat_file : str
        filename of .mat file created by BatLAB

    Returns
    -------
    seq_list : list
        of Sequence objects, one for each (filename, annotation) pair
        found in the .mat file

    Raises
    ------
    ValueError
        if the ``segType`` field of an annotation is neither an int
        nor a numpy array
    """
    mat = loadmat(mat_file, squeeze_me=True)
    seq_list = []
    # annotation structure loads as a Python dictionary with two keys
    # one maps to a list of filenames,
    # and the other to a Numpy array where each element is the annotation
    # corresponding to the filename at the same index in the list.
    # We can iterate over both by using the zip() function.
    for filename, annotation in zip(mat['filenames'], mat['annotations']):
        # below, .tolist() does not actually create a list,
        # instead gets ndarray out of a zero-length ndarray of dtype=object.
        # This is just weirdness that results from loading complicated data
        # structure in .mat file.
        seg_start_times = annotation['segFileStartTimes'].tolist()
        seg_end_times = annotation['segFileEndTimes'].tolist()
        seg_types = annotation['segType'].tolist()
        if isinstance(seg_types, int):
            # this happens when there's only one syllable in the file
            # with only one corresponding label,
            # so wrap it in a one-element array
            seg_types = np.asarray([seg_types])
        elif not isinstance(seg_types, np.ndarray):
            # something unexpected happened; an ndarray is what we expect
            # whenever there's more than one label
            raise ValueError("Unable to load labels from {}, because "
                             "the segType parsed as type {} which is "
                             "not recognized.".format(filename,
                                                      type(seg_types)))
        samp_freq = annotation['fs'].tolist()
        # convert times in seconds to sample indices
        seg_start_times_Hz = np.round(seg_start_times * samp_freq).astype(int)
        seg_end_times_Hz = np.round(seg_end_times * samp_freq).astype(int)

        seq = Sequence.from_keyword(file=filename,
                                    labels=seg_types,
                                    onsets_s=seg_start_times,
                                    offsets_s=seg_end_times,
                                    onset_inds=seg_start_times_Hz,
                                    offset_inds=seg_end_times_Hz)
        seq_list.append(seq)
    return seq_list
import scipy.io

from crowsetta import Sequence, Annotation
from crowsetta.typing import PathLike
import crowsetta


@crowsetta.formats.register_format
@crowsetta.interface.SeqLike.register
@attr.define
class Batlab:
    """Example custom annotation format.

    Loads annotations saved by BatLAB in a .mat file
    and converts them to ``crowsetta.Sequence``
    or ``crowsetta.Annotation`` instances.

    Attributes
    ----------
    name : str
        Name of the format, a class variable.
    ext : str
        Extension of files in this format, a class variable.
    annotations : numpy.ndarray
        Value of the ``'annotations'`` variable loaded from the .mat file.
    audio_paths : numpy.ndarray
        Value of the ``'filenames'`` variable loaded from the .mat file.
    mat_path : pathlib.Path
        Path to the .mat file the annotations were loaded from.
    """
    name: ClassVar[str] = 'example-custom-format'
    ext: ClassVar[str] = '.mat'

    annotations: np.ndarray = attr.field(eq=attr.cmp_using(eq=np.array_equal))
    audio_paths: np.ndarray = attr.field(eq=attr.cmp_using(eq=np.array_equal))
    mat_path: pathlib.Path = attr.field(converter=pathlib.Path)

    @classmethod
    def from_file(cls,
                  mat_path: PathLike):
        """Load BatLAB annotations from a .mat file.

        Parameters
        ----------
        mat_path : str, pathlib.Path
            Path to .mat file with annotations.

        Returns
        -------
        batlab : Batlab
            Instance holding the loaded annotations and filenames.

        Raises
        ------
        ValueError
            If the filenames and the annotations in the file
            do not have the same length.
        """
        mat_path = pathlib.Path(mat_path)
        crowsetta.validation.validate_ext(mat_path, extension=cls.ext)

        annot_mat = scipy.io.loadmat(mat_path, squeeze_me=True)

        audio_paths = annot_mat['filenames']
        annotations = annot_mat['annotations']
        if len(audio_paths) != len(annotations):
            raise ValueError(
                f'list of filenames and list of annotations in {mat_path} do not have the same length'
            )

        # the keyword must match the declared field name, ``mat_path``
        return cls(annotations=annotations,
                   audio_paths=audio_paths,
                   mat_path=mat_path)

    def to_seq(self):
        """Unpack BatLAB annotation into a list of Sequence objects.

        Example of a function that unpacks annotation from
        a complicated data structure and returns the necessary
        data as Sequence objects.

        Returns
        -------
        seqs : list
            of Sequence objects

        Raises
        ------
        ValueError
            If the ``segType`` field of an annotation is neither
            an int nor a numpy array.
        """
        seqs = []
        # ``audio_paths`` holds the filenames and ``annotations`` holds,
        # at the same index, the annotation corresponding to each filename.
        # We can iterate over both by using the zip() function.
        for filename, annotation in zip(self.audio_paths, self.annotations):
            # below, .tolist() does not actually create a list,
            # instead gets ndarray out of a zero-length ndarray of dtype=object.
            # This is just weirdness that results from loading complicated data
            # structure in .mat file.
            onsets_s = annotation['segFileStartTimes'].tolist()
            offsets_s = annotation['segFileEndTimes'].tolist()
            labels = annotation['segType'].tolist()
            if isinstance(labels, int):
                # this happens when there's only one syllable in the file
                # with only one corresponding label,
                # so wrap it in a one-element array
                labels = np.asarray([labels])
            elif not isinstance(labels, np.ndarray):
                # something unexpected happened; an ndarray is what we expect
                # whenever there's more than one label
                raise ValueError("Unable to load labels from {}, because "
                                 "the segType parsed as type {} which is "
                                 "not recognized.".format(filename,
                                                          type(labels)))

            seq = Sequence.from_keyword(labels=labels,
                                        onsets_s=onsets_s,
                                        offsets_s=offsets_s)
            seqs.append(seq)
        return seqs

    def to_annot(self):
        """Example of a function that unpacks annotation
        and returns the necessary data as a list of
        ``crowsetta.Annotation``.

        Returns
        -------
        annot_list : list
            of ``crowsetta.Annotation``, one for each filename
            paired with the sequence at the same index.
        """
        seqs = self.to_seq()

        annot_list = []
        # ``to_seq`` returns one sequence per filename, in the same order,
        # so we can iterate over both by using the zip() function.
        for filename, seq in zip(self.audio_paths, seqs):
            annot_list.append(Annotation(annot_path=self.mat_path,
                                         notated_path=filename,
                                         seq=seq))

        return annot_list
3 changes: 1 addition & 2 deletions src/crowsetta/formats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,5 @@ def register_format(format_class):
raise TypeError(
f'format class must be subclass of SeqLike or BBoxLike, but was not: {format_class}'
)
name = format_class.__name__
FORMATS[name] = format_class
FORMATS[format_class.name] = format_class
return format_class
105 changes: 105 additions & 0 deletions tests/data_for_tests/example_custom_format/example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import pathlib
from typing import ClassVar

import attr
import numpy as np
import scipy.io

from crowsetta import Sequence, Annotation
from crowsetta.typing import PathLike
import crowsetta

@crowsetta.formats.register_format
@crowsetta.interface.SeqLike.register
@attr.define
class Custom:
    """Example custom annotation format.

    Attributes
    ----------
    name : str
        Name of the format, a class variable.
    ext : str
        Extension of files in this format, a class variable.
    annotations : numpy.ndarray
        Value of the ``'annotations'`` variable loaded from the .mat file.
    audio_paths : numpy.ndarray
        Value of the ``'filenames'`` variable loaded from the .mat file.
    annot_path : pathlib.Path
        Path to the .mat file the annotations were loaded from.
    """
    name: ClassVar[str] = 'example-custom-format'
    ext: ClassVar[str] = '.mat'

    annotations: np.ndarray = attr.field(eq=attr.cmp_using(eq=np.array_equal))
    audio_paths: np.ndarray = attr.field(eq=attr.cmp_using(eq=np.array_equal))
    annot_path: pathlib.Path = attr.field(converter=pathlib.Path)

    @classmethod
    def from_file(cls,
                  annot_path: PathLike) -> 'Custom':
        """Load annotations from a .mat file.

        Parameters
        ----------
        annot_path : str, pathlib.Path
            Path to .mat file with annotations.

        Returns
        -------
        custom : Custom
            Instance holding the loaded annotations and filenames.

        Raises
        ------
        ValueError
            If the filenames and the annotations in the file
            do not have the same length.
        """
        annot_path = pathlib.Path(annot_path)
        crowsetta.validation.validate_ext(annot_path, extension=cls.ext)

        # annotation structure loads as a Python dictionary with two keys
        # one maps to a list of filenames,
        # and the other to a Numpy record array,
        # where each element is the annotation
        # corresponding to the filename at the same index in the list.
        annot_mat = scipy.io.loadmat(annot_path, squeeze_me=True)
        audio_paths = annot_mat['filenames']
        annotations = annot_mat['annotations']
        if len(audio_paths) != len(annotations):
            raise ValueError(
                f'list of filenames and list of annotations in {annot_path} do not have the same length'
            )

        return cls(annotations=annotations,
                   audio_paths=audio_paths,
                   annot_path=annot_path)

    def to_seq(self):
        """Unpack the annotations into a list of Sequence objects.

        Returns
        -------
        seqs : list
            of Sequence objects

        Raises
        ------
        ValueError
            If the ``segType`` field of an annotation is neither
            an int nor a numpy array.
        """
        seqs = []
        # ``audio_paths`` holds the filenames and ``annotations`` holds,
        # at the same index, the annotation corresponding to each filename.
        # We can iterate over both by using the zip() function.
        for filename, annotation in zip(self.audio_paths, self.annotations):
            # below, .tolist() does not actually create a list,
            # instead gets ndarray out of a zero-length ndarray of dtype=object.
            # This is just weirdness that results from loading complicated data
            # structure in .mat file.
            onsets_s = annotation['segFileStartTimes'].tolist()
            offsets_s = annotation['segFileEndTimes'].tolist()
            labels = annotation['segType'].tolist()
            if isinstance(labels, int):
                # this happens when there's only one syllable in the file
                # with only one corresponding label,
                # so wrap it in a one-element array
                labels = np.asarray([labels])
            elif not isinstance(labels, np.ndarray):
                # something unexpected happened; an ndarray is what we expect
                # whenever there's more than one label
                raise ValueError("Unable to load labels from {}, because "
                                 "the segType parsed as type {} which is "
                                 "not recognized.".format(filename,
                                                          type(labels)))
            seq = Sequence.from_keyword(labels=labels,
                                        onsets_s=onsets_s,
                                        offsets_s=offsets_s)
            seqs.append(seq)
        return seqs

    def to_annot(self):
        """Example of a function that unpacks annotation from
        a complicated data structure and returns the necessary
        data as a list of Annotation objects.

        Returns
        -------
        annot_list : list
            of Annotation objects, one for each filename
            paired with the sequence at the same index.
        """
        seqs = self.to_seq()

        annot_list = []
        # ``to_seq`` returns one sequence per filename, in the same order,
        # so we can iterate over both by using the zip() function.
        for filename, seq in zip(self.audio_paths, seqs):
            annot_list.append(Annotation(annot_path=self.annot_path,
                                         notated_path=filename,
                                         seq=seq))

        return annot_list
44 changes: 0 additions & 44 deletions tests/data_for_tests/example_user_format/example.py

This file was deleted.

4 changes: 2 additions & 2 deletions tests/fixtures/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ def timit_phn_as_generic_seq_csv():
return TIMIT_PHN_AS_GENERIC_SEQ_CSV


# Older constant, kept so that any remaining references still resolve.
EXAMPLE_USER_FORMAT_AS_GENERIC_SEQ_CSV = CSV_ROOT / 'example_user_annotation.csv'
# Path under CSV_ROOT to 'example_custom_format.csv'.
EXAMPLE_CUSTOM_FORMAT_AS_GENERIC_SEQ_CSV = CSV_ROOT / 'example_custom_format.csv'


@pytest.fixture
def example_user_format_as_generic_seq_csv():
    """Return the path to 'example_custom_format.csv'.

    The fixture keeps its original name so existing tests
    that request it keep working.
    """
    # Previously an earlier ``return`` of the old constant made this
    # statement unreachable.
    return EXAMPLE_CUSTOM_FORMAT_AS_GENERIC_SEQ_CSV


# Path under CSV_ROOT to the test csv named 'missing_fields_in_header.csv'.
CSV_MISSING_FIELDS_IN_HEADER = CSV_ROOT / 'missing_fields_in_header.csv'
Expand Down
18 changes: 18 additions & 0 deletions tests/scripts/remake_test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
remakes the .csv files used for testing the `'generic-seq'` format
"""
from pathlib import Path
import sys

import pandas as pd

Expand Down Expand Up @@ -53,6 +54,23 @@ def remake_timit_phn_as_generic_seq_csv():
timit_generic_seq.to_file(csv_path=csv_path, basename=True)


def remake_example_custom_format_as_generic_seq_csv():
    """Remake the csv that holds the example custom format
    annotations converted to the ``'generic-seq'`` format.

    Imports the ``example`` module from the test data directory,
    loads ``bird1_annotation.mat`` with the ``'example-custom-format'``
    format, converts it to annotations,
    and saves those as ``example_custom_format.csv``.
    """
    example_custom_format_dir = TEST_DATA / 'example_custom_format'
    # entries on sys.path should be strings, not pathlib.Path instances,
    # otherwise the import system may ignore them
    example_module_dir = str(example_custom_format_dir)
    if example_module_dir not in sys.path:
        sys.path.append(example_module_dir)
    import example  # noqa: F401, registers custom format using `crowsetta.formats.register_format`

    annot_path = example_custom_format_dir / 'bird1_annotation.mat'
    scribe = crowsetta.Transcriber(format='example-custom-format')
    example_ = scribe.from_file(annot_path)
    annots = example_.to_annot()
    custom_format_generic_seq = crowsetta.formats.seq.GenericSeq(annots=annots)
    csv_path = TEST_DATA / 'csv' / 'example_custom_format.csv'
    print(
        f'saving csv: {csv_path}'
    )
    custom_format_generic_seq.to_file(csv_path=csv_path, basename=True)


def remake_invalid_fields_in_header_csv(source_csv_path):
df = pd.read_csv(source_csv_path)
df['invalid'] = df['label'].copy()
Expand Down

0 comments on commit 6509fac

Please sign in to comment.