From 5fe2153b607db98a59e5314ee49231df884f1652 Mon Sep 17 00:00:00 2001
From: hadware
Date: Fri, 7 Jan 2022 20:09:00 +0100
Subject: [PATCH 1/8] Re-added files from backup branch

---
 .../audio/pipelines/multilabel_detection.py   | 302 ++++++++++++++++++
 .../segmentation/voice_type_classification.py |  99 ++++++
 2 files changed, 401 insertions(+)
 create mode 100644 pyannote/audio/pipelines/multilabel_detection.py
 create mode 100644 pyannote/audio/tasks/segmentation/voice_type_classification.py

diff --git a/pyannote/audio/pipelines/multilabel_detection.py b/pyannote/audio/pipelines/multilabel_detection.py
new file mode 100644
index 000000000..427d827c0
--- /dev/null
+++ b/pyannote/audio/pipelines/multilabel_detection.py
@@ -0,0 +1,302 @@
+# The MIT License (MIT)
+#
+# Copyright (c) 2017-2020 CNRS
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+from dataclasses import dataclass
+from functools import reduce
+from itertools import chain
+from typing import Union, Optional, List, Dict, TYPE_CHECKING, Text
+
+import numpy as np
+from numba.typed import List
+from pyannote.core import Annotation, SlidingWindowFeature
+from pyannote.metrics.base import BaseMetric
+from pyannote.metrics.detection import DetectionPrecisionRecallFMeasure
+from pyannote.metrics.identification import IdentificationErrorRate
+from pyannote.pipeline.parameter import ParamDict, Uniform
+from sortedcontainers import SortedDict
+
+from pyannote.audio import Inference
+from pyannote.audio.core.io import AudioFile
+from pyannote.audio.core.pipeline import Pipeline
+from .utils import PipelineModel, get_devices, get_model
+from ..utils.signal import Binarize
+
+SpeakerClass = Text
+MetaClasses = Dict[SpeakerClass, List[SpeakerClass]]
+
+if TYPE_CHECKING:
+    from ..tasks.segmentation.voice_type_classification import VoiceTypeClassification
+
+
+@dataclass
+class MultilabelDetectionSpecifications:
+    classes: List[SpeakerClass]
+    unions: MetaClasses
+    intersections: MetaClasses
+    unions_idx: Optional[SortedDict] = None
+    intersections_idx: Optional[SortedDict] = None
+
+    def __post_init__(self):
+        # for each metaclass, map the metaclass label to the vector of its
+        # classes' ids (used for encoding)
+        self.unions_idx = self.to_metaclasses_idx(self.unions,
+                                                  self.classes)
+        self.intersections_idx = self.to_metaclasses_idx(self.intersections,
+                                                         self.classes)
+
+    @property
+    def all_classes(self) -> List[str]:
+        return (self.classes +
+                list(self.unions.keys()) +
+                list(self.intersections.keys()))
+
+    @staticmethod
+    def to_metaclasses_idx(metaclasses: MetaClasses, classes: List[SpeakerClass]) -> SortedDict:
+        return SortedDict({
+            intersection_label: np.array([classes.index(klass)
+                                          for klass in intersection_classes])
+            for intersection_label, intersection_classes
+            in metaclasses.items()
+        })
+
+    def derive_unions_encoding(self, one_hot_array: np.ndarray):
+        arrays: List[np.ndarray] = []
+        for label, idx in self.unions_idx.items():
+            arrays.append(one_hot_array[:, idx].max(axis=1))
+        return np.vstack(arrays).swapaxes(0, 1)
+
+    def derive_intersections_encoding(self, one_hot_array: np.ndarray):
+        arrays: List[np.ndarray] = []
+        for label, idx in self.intersections_idx.items():
+            arrays.append(one_hot_array[:, idx].min(axis=1))
+        return np.vstack(arrays).swapaxes(0, 1)
+
+    def derive_reference(self, annotation: Annotation) -> Annotation:
+        derived = annotation.subset(self.classes)
+        # adding union labels
+        for union_label, subclasses in self.unions.items():
+            mapping = {k: union_label for k in subclasses}
+            metalabel_annot = annotation.subset(subclasses).rename_labels(mapping=mapping)
+            derived.update(metalabel_annot.support())
+
+        # adding intersection labels
+        for intersect_label, subclasses in self.intersections.items():
+            subclasses_tl = [annotation.label_timeline(subclass) for subclass in subclasses]
+            overlap_tl = reduce(lambda x, y: x.crop(y), subclasses_tl)
+            derived.update(overlap_tl.to_annotation(intersect_label))
+
+        return derived
+
+    @classmethod
+    def from_parameters(
+            cls,
+            classes: List[SpeakerClass],  # VTC-specific parameter
+            unions: Optional[MetaClasses] = None,
+            intersections: Optional[MetaClasses] = None, ) \
+            -> 'MultilabelDetectionSpecifications':
+        if unions is not None:
+            assert set(chain.from_iterable(unions.values())).issubset(set(classes))
+
+        if intersections is not None:
+            assert set(chain.from_iterable(intersections.values())).issubset(set(classes))
+
+        classes = sorted(list(set(classes)))
+        return cls(classes,
+                   unions if unions else dict(),
+                   intersections if intersections else dict())
+
+
+class MultilabelFMeasure(BaseMetric):
+    """Compute the mean Fscore over all labels
+
+    """
+
+    @classmethod
+    def metric_name(cls):
+        return "AVG[Labels]"
+
+    def __init__(self, mtl_specs: MultilabelDetectionSpecifications,  # noqa
+                 collar=0.0, skip_overlap=False,
+                 beta=1., parallel=False, **kwargs):
+        self.parallel = parallel
+        self.metric_name_ = self.metric_name()
+        self.components_ = set(self.metric_components())
+        self.reset()
+        self.collar = collar
+        self.skip_overlap = skip_overlap
+        self.beta = beta
+        self.mtl_specs = mtl_specs
+        self.submetrics: Dict[str, DetectionPrecisionRecallFMeasure] = {
+            label: DetectionPrecisionRecallFMeasure(collar=collar,
+                                                    skip_overlap=skip_overlap,
+                                                    beta=beta,
+                                                    **kwargs)
+            for label in self.mtl_specs.all_classes
+        }
+
+    def reset(self):
+        super().reset()
+        for submetric in self.submetrics.values():
+            submetric.reset()
+
+    def compute_components(self, reference: Annotation, hypothesis: Annotation, uem=None, **kwargs):
+
+        details = self.init_components()
+        reference = self.mtl_specs.derive_reference(reference)
+        for label, submetric in self.submetrics.items():
+            details[label] = submetric(reference=reference.subset([label]),
+                                       hypothesis=hypothesis.subset([label]),
+                                       uem=uem,
+                                       **kwargs)
+        return details
+
+    def compute_metric(self, detail: Dict[str, float]):
+        return np.mean(detail.values())
+
+    def __abs__(self):
+        return np.mean([abs(submetric) for submetric in self.submetrics.values()])
+
+
+class MultilabelIER(IdentificationErrorRate):
+
+    def __init__(self, mtl_specs: MultilabelDetectionSpecifications,
+                 *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.mtl_specs = mtl_specs
+
+    def compute_components(self, reference, hypothesis, uem=None,
+                           collar=None, skip_overlap=None, **kwargs):
+        # deriving labels
+        reference = self.mtl_specs.derive_reference(reference)
+        return super().compute_components(reference, hypothesis,
+                                          uem=uem, collar=collar,
+                                          skip_overlap=skip_overlap,
+                                          **kwargs)
+
+
+class MultilabelDetection(Pipeline):
+    """Multilabel voice type detection pipeline"""
+
+    def __init__(self,
+                 segmentation: PipelineModel = "pyannote/vtc",
+                 fscore: bool = False,
+                 **inference_kwargs,
+                 ):
+
+        super().__init__()
+
+        self.segmentation = segmentation
+        self.fscore = fscore
+
+        # load model and send it to GPU (when available and not already on GPU)
+        model = get_model(segmentation)
+        if model.device.type == "cpu":
+            (segmentation_device,) = get_devices(needs=1)
+            model.to(segmentation_device)
+
+        task: 'VoiceTypeClassification' = model.task
+        self.mtl_specs = task.clsf_specs
+        self.labels = task.clsf_specs.all_classes
+        self.segmentation_inference_ = Inference(model, **inference_kwargs)
+
+        self.binarize_hparams = ParamDict(**{
+            class_name: ParamDict(
+                onset=Uniform(0., 1.),
+                offset=Uniform(0., 1.),
+                min_duration_on=Uniform(0., 2.),
+                min_duration_off=Uniform(0., 2.),
+                pad_onset=Uniform(-1., 1.),
+                pad_offset=Uniform(-1., 1.)
+            ) for class_name in self.labels
+        })
+
+    def initialize(self):
+        """Initialize pipeline with current set of parameters"""
+        self.freeze({'binarize_hparams': {
+            class_name: {
+                "pad_onset": 0.0,
+                "pad_offset": 0.0
+            } for class_name in self.labels
+        }})
+        self._binarizers = {
+            class_name: Binarize(
+                onset=self.binarize_hparams[class_name]["onset"],
+                offset=self.binarize_hparams[class_name]["offset"],
+                min_duration_on=self.binarize_hparams[class_name]["min_duration_on"],
+                min_duration_off=self.binarize_hparams[class_name]["min_duration_off"],
+                pad_onset=self.binarize_hparams[class_name]["pad_onset"],
+                pad_offset=self.binarize_hparams[class_name]["pad_offset"])
+            for class_name in self.labels
+        }
+
+    CACHED_ACTIVATIONS = "@multilabel_detection/activations"
+
+    def apply(self, file: AudioFile) -> Annotation:
+        """Apply voice type classification
+
+        Parameters
+        ----------
+        file : AudioFile
+            Processed file.
+
+        Returns
+        -------
+        speech : `pyannote.core.Annotation`
+            Annotated classification.
+        """
+        if self.training:
+            if self.CACHED_ACTIVATIONS not in file:
+                file[self.CACHED_ACTIVATIONS] = self.segmentation_inference_(file)
+        else:
+            file[self.CACHED_ACTIVATIONS] = self.segmentation_inference_(file)
+
+        # for each class name, add
+        multilabel_scores: SlidingWindowFeature = file[self.CACHED_ACTIVATIONS]
+        full_annot = Annotation(uri=file["uri"])
+        for class_idx, class_name in enumerate(self.labels):
+            # selecting scores for only one label
+            label_scores_array: np.ndarray = multilabel_scores.data[:, class_idx]
+            # creating a fake "num_classes" dim
+            label_scores_array = np.expand_dims(label_scores_array, axis=1)
+            # creating a new sliding window for that label
+            label_scores = SlidingWindowFeature(label_scores_array,
+                                                multilabel_scores.sliding_window)
+            binarizer: Binarize = self._binarizers[class_name]
+            label_annot = binarizer(label_scores)
+            full_annot.update(label_annot)
+
+        return full_annot
+
+    def get_metric(self) -> Union[MultilabelFMeasure, IdentificationErrorRate]:
+        """Return new instance of identification metric"""
+
+        if self.fscore:
+            return MultilabelFMeasure(mtl_specs=self.mtl_specs,
+                                      collar=0.0, skip_overlap=False)
+        else:
+            return MultilabelIER(mtl_specs=self.mtl_specs,
+                                 collar=0.0, skip_overlap=False)
+
+    def get_direction(self):
+        if self.fscore:
+            return "maximize"
+        else:
+            return "minimize"
diff --git a/pyannote/audio/tasks/segmentation/voice_type_classification.py b/pyannote/audio/tasks/segmentation/voice_type_classification.py
new file mode 100644
index 000000000..0363b10f0
--- /dev/null
+++ b/pyannote/audio/tasks/segmentation/voice_type_classification.py
@@ -0,0 +1,99 @@
+# MIT License
+#
+# Copyright (c) 2020-2021 CNRS
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+from typing import Tuple, Union, Optional, Text
+
+import numpy as np
+from numba.typed import List
+from pyannote.database import Protocol
+from torch_audiomentations.core.transforms_interface import BaseWaveformTransform
+
+from .mixins import SegmentationTaskMixin
+from ...core.task import Task, Specifications, Problem, Resolution
+from ...pipelines.multilabel_detection import MultilabelDetectionSpecifications, SpeakerClass, MetaClasses
+
+
+class VoiceTypeClassification(SegmentationTaskMixin, Task):
+    """Voice type classification task"""
+
+    ACRONYM = "vtc"
+
+    def __init__(
+            self,
+            protocol: Protocol,
+            classes: List[SpeakerClass],  # VTC-specific parameter
+            unions: Optional[MetaClasses] = None,
+            intersections: Optional[MetaClasses] = None,
+            duration: float = 5.0,
+            warm_up: Union[float, Tuple[float, float]] = 0.0,
+            balance: Text = None,
+            weight: Text = None,
+            batch_size: int = 32,
+            num_workers: int = None,
+            pin_memory: bool = False,
+            augmentation: BaseWaveformTransform = None,
+    ):
+        super().__init__(
+            protocol,
+            duration=duration,
+            min_duration=duration,
+            warm_up=warm_up,
+            batch_size=batch_size,
+            num_workers=num_workers,
+            pin_memory=pin_memory,
+            augmentation=augmentation,
+        )
+        self.balance = balance
+        self.weight = weight
+
+        self.clsf_specs = MultilabelDetectionSpecifications. \
+            from_parameters(classes, unions, intersections)
+
+        # setting up specifications, used to set up the model by pt-lightning
+        self.specifications = Specifications(
+            # it is a multi-label classification problem
+            problem=Problem.MULTI_LABEL_CLASSIFICATION,
+            # we expect the model to output one prediction
+            # per frame of the chunk
+            resolution=Resolution.FRAME,
+            # the model will ingest chunks with that duration (in seconds)
+            duration=self.duration,
+            # human-readable names of classes
+            classes=self.clsf_specs.all_classes
+        )
+
+    @property
+    def chunk_labels(self) -> List[SpeakerClass]:
+        # Only used by `prepare_chunk`, which doesn't need to know
+        # about union/intersections.
+        return self.clsf_specs.classes
+
+    def prepare_y(self, one_hot_y: np.ndarray) -> np.ndarray:
+        # one_hot_y is of shape (Time, Classes)
+        metaclasses_one_hots = []
+        if self.clsf_specs.unions:
+            metaclasses_one_hots.append(self.clsf_specs.derive_unions_encoding(one_hot_y))
+        if self.clsf_specs.intersections:
+            metaclasses_one_hots.append(self.clsf_specs.derive_intersections_encoding(one_hot_y))
+
+        if metaclasses_one_hots:
+            one_hot_y = np.hstack([one_hot_y] + metaclasses_one_hots)
+        return np.int64(one_hot_y)

From 1d59d446720320830874b1866b4c4ba25106f7da Mon Sep 17 00:00:00 2001
From: hadware
Date: Fri, 7 Jan 2022 20:21:08 +0100
Subject: [PATCH 2/8] Re-added to init

---
 pyannote/audio/pipelines/__init__.py | 2 ++
 pyannote/audio/tasks/__init__.py     | 1 +
 2 files changed, 3 insertions(+)

diff --git a/pyannote/audio/pipelines/__init__.py b/pyannote/audio/pipelines/__init__.py
index a00d54d46..6380cd43e 100644
--- a/pyannote/audio/pipelines/__init__.py
+++ b/pyannote/audio/pipelines/__init__.py
@@ -24,10 +24,12 @@
 from .resegmentation import Resegmentation
 from .speaker_diarization import SpeakerDiarization
 from .voice_activity_detection import VoiceActivityDetection
+from .multilabel_detection import MultilabelDetection
 
 __all__ = [
     "VoiceActivityDetection",
     "OverlappedSpeechDetection",
     "SpeakerDiarization",
     "Resegmentation",
+    "MultilabelDetection"
 ]

diff --git a/pyannote/audio/tasks/__init__.py b/pyannote/audio/tasks/__init__.py
index 3854139ab..0892d8b24 100644
--- a/pyannote/audio/tasks/__init__.py
+++ b/pyannote/audio/tasks/__init__.py
@@ -25,6 +25,7 @@
 from .segmentation.overlapped_speech_detection import (  # isort:skip
     OverlappedSpeechDetection,
 )
+from .segmentation.voice_type_classification import VoiceTypeClassification # isort:skip
 
 from .segmentation.speaker_tracking import SpeakerTracking  # isort:skip
 

From 7ebbc7480f2bfa0e546a2ad6b8c10d8a65c2c26b Mon Sep 17 00:00:00 2001
From: hadware
Date: Tue, 11 Jan 2022 21:15:45 +0100
Subject: [PATCH 3/8] Re-added __init__ references, re-added
 VoiceTypeClassification.yaml default config

---
 pyannote/audio/cli/train_config/hydra/train.yaml           | 4 ++--
 .../cli/train_config/task/VoiceTypeClassification.yaml     | 9 +++++++++
 pyannote/audio/tasks/__init__.py                           | 2 +-
 3 files changed, 12 insertions(+), 3 deletions(-)
 create mode 100644 pyannote/audio/cli/train_config/task/VoiceTypeClassification.yaml

diff --git a/pyannote/audio/cli/train_config/hydra/train.yaml b/pyannote/audio/cli/train_config/hydra/train.yaml
index b2cb3bd45..b77197bcc 100644
--- a/pyannote/audio/cli/train_config/hydra/train.yaml
+++ b/pyannote/audio/cli/train_config/hydra/train.yaml
@@ -35,7 +35,7 @@ help:
   {optimizer} can be any of the following
   * adam (default) = Adam optimizer
 
-  {trainer} can be any of the following 
+  {trainer} can be any of the following
   * fast_dev_run for debugging
   * default (default) for training the model
@@ -70,7 +70,7 @@ help:
   1. define your_package.YourTask (or your_package.YourModel) class
   2. create file /path/to/your_config/task/your_task.yaml (or /path/to/your_config/model/your_model.yaml)
      # @package _group_
-     _target_: your_package.YourTask # or YourModel 
+     _target_: your_package.YourTask # or YourModel
      param1: value1
      param2: value2
   3. call pyannote-audio-train --config-dir /path/to/your_config task=your_task task.param1=modified_value1 model=your_model ...
diff --git a/pyannote/audio/cli/train_config/task/VoiceTypeClassification.yaml b/pyannote/audio/cli/train_config/task/VoiceTypeClassification.yaml
new file mode 100644
index 000000000..7093f9008
--- /dev/null
+++ b/pyannote/audio/cli/train_config/task/VoiceTypeClassification.yaml
@@ -0,0 +1,9 @@
+# @package _group_
+_target_: pyannote.audio.tasks.VoiceTypeClassification
+duration: 2.0
+warm_up: 0.0
+balance: null
+weight: null
+batch_size: 32
+num_workers: null
+pin_memory: False
\ No newline at end of file
diff --git a/pyannote/audio/tasks/__init__.py b/pyannote/audio/tasks/__init__.py
index 0892d8b24..d155884e0 100644
--- a/pyannote/audio/tasks/__init__.py
+++ b/pyannote/audio/tasks/__init__.py
@@ -25,7 +25,7 @@
 from .segmentation.overlapped_speech_detection import (  # isort:skip
     OverlappedSpeechDetection,
 )
-from .segmentation.voice_type_classification import VoiceTypeClassification # isort:skip
+from .segmentation.voice_type_classification import VoiceTypeClassification  # isort:skip
 
 from .segmentation.speaker_tracking import SpeakerTracking  # isort:skip
 

From 5cc74dabacff7951be85d6ed0bc6d7ae807dcde7 Mon Sep 17 00:00:00 2001
From: Hadrien Titeux
Date: Fri, 28 Jan 2022 16:25:33 +0100
Subject: [PATCH 4/8] Fixing Fscore metric, fixing MultilabelPipeline apply
 code

---
 .../audio/pipelines/multilabel_detection.py | 24 +++++++++++--------
 1 file changed, 14 insertions(+), 10 deletions(-)

diff --git a/pyannote/audio/pipelines/multilabel_detection.py b/pyannote/audio/pipelines/multilabel_detection.py
index 427d827c0..06ee76a9c 100644
--- a/pyannote/audio/pipelines/multilabel_detection.py
+++ b/pyannote/audio/pipelines/multilabel_detection.py
@@ -129,6 +129,9 @@ class MultilabelFMeasure(BaseMetric):
 
     """
 
+    def metric_components(self):
+        return ["AVG[Fscore]"] + self.mtl_specs.all_classes
+
     @classmethod
     def metric_name(cls):
         return "AVG[Labels]"
@@ -138,12 +141,12 @@ def __init__(self, mtl_specs: MultilabelDetectionSpecifications,  # noqa
                  beta=1., parallel=False, **kwargs):
         self.parallel = parallel
         self.metric_name_ = self.metric_name()
-        self.components_ = set(self.metric_components())
-        self.reset()
         self.collar = collar
         self.skip_overlap = skip_overlap
         self.beta = beta
         self.mtl_specs = mtl_specs
+        self.components_ = set(self.metric_components())
+
         self.submetrics: Dict[str, DetectionPrecisionRecallFMeasure] = {
             label: DetectionPrecisionRecallFMeasure(collar=collar,
                                                     skip_overlap=skip_overlap,
@@ -152,6 +155,8 @@ def __init__(self, mtl_specs: MultilabelDetectionSpecifications,  # noqa
             for label in self.mtl_specs.all_classes
         }
 
+        self.reset()
+
     def reset(self):
         super().reset()
         for submetric in self.submetrics.values():
@@ -169,7 +174,7 @@ def compute_components(self, reference: Annotation, hypothesis: Annotation, uem=
         return details
 
     def compute_metric(self, detail: Dict[str, float]):
-        return np.mean(detail.values())
+        return np.mean(list(detail.values()))
 
     def __abs__(self):
         return np.mean([abs(submetric) for submetric in self.submetrics.values()])
@@ -262,10 +267,8 @@ def apply(self, file: AudioFile) -> Annotation:
         speech : `pyannote.core.Annotation`
             Annotated classification.
""" - if self.training: - if self.CACHED_ACTIVATIONS not in file: - file[self.CACHED_ACTIVATIONS] = self.segmentation_inference_(file) - else: + if self.CACHED_ACTIVATIONS not in file: + print(f"computing activation for file {file['uri']}") file[self.CACHED_ACTIVATIONS] = self.segmentation_inference_(file) # for each class name, add @@ -280,9 +283,10 @@ def apply(self, file: AudioFile) -> Annotation: label_scores = SlidingWindowFeature(label_scores_array, multilabel_scores.sliding_window) binarizer: Binarize = self._binarizers[class_name] - label_annot = binarizer(label_scores) - full_annot.update(label_annot) - + class_annot = binarizer(label_scores) + class_tl = class_annot.support().get_timeline() + for seg in class_tl: + full_annot[seg] = class_name return full_annot def get_metric(self) -> Union[MultilabelFMeasure, IdentificationErrorRate]: From 934cf011eb389aee1ffa60f8d19370deab526189 Mon Sep 17 00:00:00 2001 From: Hadrien Titeux Date: Sat, 29 Jan 2022 03:12:37 +0100 Subject: [PATCH 5/8] Fixed multilabel pipeline apply method. --- .../audio/pipelines/multilabel_detection.py | 29 +++++++++++-------- .../pipelines/voice_activity_detection.py | 1 + 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/pyannote/audio/pipelines/multilabel_detection.py b/pyannote/audio/pipelines/multilabel_detection.py index 06ee76a9c..b870c4172 100644 --- a/pyannote/audio/pipelines/multilabel_detection.py +++ b/pyannote/audio/pipelines/multilabel_detection.py @@ -26,6 +26,9 @@ import numpy as np from numba.typed import List +from pyannote.audio import Inference +from pyannote.audio.core.io import AudioFile +from pyannote.audio.core.pipeline import Pipeline from pyannote.core import Annotation, SlidingWindowFeature from pyannote.metrics.base import BaseMetric from pyannote.metrics.detection import DetectionPrecisionRecallFMeasure @@ -33,9 +36,6 @@ from pyannote.pipeline.parameter import ParamDict, Uniform from sortedcontainers import SortedDict -from pyannote.audio import Inference -from pyannote.audio.core.io import AudioFile -from pyannote.audio.core.pipeline import Pipeline from .utils import PipelineModel, get_devices, get_model from ..utils.signal import Binarize @@ -130,7 +130,7 @@ class MultilabelFMeasure(BaseMetric): """ def metric_components(self): - return ["AVG[Fscore]"] + self.mtl_specs.all_classes + return self.mtl_specs.all_classes @classmethod def metric_name(cls): @@ -267,12 +267,17 @@ def apply(self, file: AudioFile) -> Annotation: speech : `pyannote.core.Annotation` Annotated classification. 
""" - if self.CACHED_ACTIVATIONS not in file: - print(f"computing activation for file {file['uri']}") - file[self.CACHED_ACTIVATIONS] = self.segmentation_inference_(file) - # for each class name, add - multilabel_scores: SlidingWindowFeature = file[self.CACHED_ACTIVATIONS] + multilabel_scores: SlidingWindowFeature + if self.training: + if self.CACHED_ACTIVATIONS not in file: + file[self.CACHED_ACTIVATIONS] = self.segmentation_inference_(file) + + multilabel_scores = file[self.CACHED_ACTIVATIONS] + else: + multilabel_scores = self.segmentation_inference_(file) + + # for each class name, add class-specific "VAD" pipeline full_annot = Annotation(uri=file["uri"]) for class_idx, class_name in enumerate(self.labels): # selecting scores for only one label @@ -284,9 +289,9 @@ def apply(self, file: AudioFile) -> Annotation: multilabel_scores.sliding_window) binarizer: Binarize = self._binarizers[class_name] class_annot = binarizer(label_scores) - class_tl = class_annot.support().get_timeline() - for seg in class_tl: - full_annot[seg] = class_name + class_annot.rename_labels({label: class_name for label in class_annot.labels()}, copy=False) + full_annot.update(class_annot) + return full_annot def get_metric(self) -> Union[MultilabelFMeasure, IdentificationErrorRate]: diff --git a/pyannote/audio/pipelines/voice_activity_detection.py b/pyannote/audio/pipelines/voice_activity_detection.py index 104fc84f1..66919c45d 100644 --- a/pyannote/audio/pipelines/voice_activity_detection.py +++ b/pyannote/audio/pipelines/voice_activity_detection.py @@ -178,6 +178,7 @@ def apply(self, file: AudioFile) -> Annotation: if self.CACHED_ACTIVATIONS not in file: file[self.CACHED_ACTIVATIONS] = self.segmentation_inference_(file) else: + # TODO : isn't this supposed to be a no-cache inference? 
             file[self.CACHED_ACTIVATIONS] = self.segmentation_inference_(file)
 
         speech: Annotation = self._binarize(file[self.CACHED_ACTIVATIONS])

From 1f0c63eea87c1fc07c2276788aaf59cda2a50140 Mon Sep 17 00:00:00 2001
From: hadware
Date: Thu, 3 Feb 2022 12:50:34 +0100
Subject: [PATCH 6/8] Fixing imports

---
 pyannote/audio/pipelines/multilabel_detection.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/pyannote/audio/pipelines/multilabel_detection.py b/pyannote/audio/pipelines/multilabel_detection.py
index 427d827c0..d7b5d78b1 100644
--- a/pyannote/audio/pipelines/multilabel_detection.py
+++ b/pyannote/audio/pipelines/multilabel_detection.py
@@ -25,7 +25,6 @@
 from typing import Union, Optional, List, Dict, TYPE_CHECKING, Text
 
 import numpy as np
-from numba.typed import List
 from pyannote.core import Annotation, SlidingWindowFeature
 from pyannote.metrics.base import BaseMetric
 from pyannote.metrics.detection import DetectionPrecisionRecallFMeasure

From b0ec1a28e1b049cd63b00893430ef1382b0869e9 Mon Sep 17 00:00:00 2001
From: hadware
Date: Thu, 3 Feb 2022 12:53:21 +0100
Subject: [PATCH 7/8] Fixing imports (again)

---
 pyannote/audio/pipelines/multilabel_detection.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/pyannote/audio/pipelines/multilabel_detection.py b/pyannote/audio/pipelines/multilabel_detection.py
index b870c4172..c5dd7671d 100644
--- a/pyannote/audio/pipelines/multilabel_detection.py
+++ b/pyannote/audio/pipelines/multilabel_detection.py
@@ -25,10 +25,6 @@
 from typing import Union, Optional, List, Dict, TYPE_CHECKING, Text
 
 import numpy as np
-from numba.typed import List
-from pyannote.audio import Inference
-from pyannote.audio.core.io import AudioFile
-from pyannote.audio.core.pipeline import Pipeline
 from pyannote.core import Annotation, SlidingWindowFeature
 from pyannote.metrics.base import BaseMetric
 from pyannote.metrics.detection import DetectionPrecisionRecallFMeasure
@@ -36,6 +32,9 @@
 from pyannote.pipeline.parameter import ParamDict, Uniform
 from sortedcontainers import SortedDict
 
+from pyannote.audio import Inference
+from pyannote.audio.core.io import AudioFile
+from pyannote.audio.core.pipeline import Pipeline
 from .utils import PipelineModel, get_devices, get_model
 from ..utils.signal import Binarize
 

From a28bacbc0815fb775f0d3c5c57f84b9f560cf3b4 Mon Sep 17 00:00:00 2001
From: hadware
Date: Thu, 3 Feb 2022 13:03:52 +0100
Subject: [PATCH 8/8] Fixing imports (again^2)

---
 pyannote/audio/tasks/segmentation/voice_type_classification.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/pyannote/audio/tasks/segmentation/voice_type_classification.py b/pyannote/audio/tasks/segmentation/voice_type_classification.py
index 0363b10f0..fb86a2773 100644
--- a/pyannote/audio/tasks/segmentation/voice_type_classification.py
+++ b/pyannote/audio/tasks/segmentation/voice_type_classification.py
@@ -19,10 +19,9 @@
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
-from typing import Tuple, Union, Optional, Text
+from typing import Tuple, Union, Optional, Text, List
 
 import numpy as np
-from numba.typed import List
 from pyannote.database import Protocol
 from torch_audiomentations.core.transforms_interface import BaseWaveformTransform
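A quick way to sanity-check the union/intersection encoding introduced in PATCH 1/8 is to build the specifications by hand and derive the metaclass encodings from a small one-hot matrix. This is a minimal sketch, assuming the branch is installed; the class names (KCHI, FEM, MAL) and metaclass names (SPEECH, OVL) are made up for the example, while `MultilabelDetectionSpecifications` and its methods come from the patches above:

```python
import numpy as np

from pyannote.audio.pipelines.multilabel_detection import (
    MultilabelDetectionSpecifications,
)

# made-up class inventory: three base classes, one union, one intersection
specs = MultilabelDetectionSpecifications.from_parameters(
    classes=["KCHI", "FEM", "MAL"],
    unions={"SPEECH": ["KCHI", "FEM", "MAL"]},
    intersections={"OVL": ["FEM", "MAL"]},
)

# from_parameters() sorts the base classes, so columns are FEM, KCHI, MAL
print(specs.classes)      # ['FEM', 'KCHI', 'MAL']
print(specs.all_classes)  # ['FEM', 'KCHI', 'MAL', 'SPEECH', 'OVL']

# toy frame-wise one-hot matrix, shape (frames, classes)
one_hot = np.array([
    [1, 0, 0],  # FEM alone
    [1, 1, 0],  # FEM + KCHI
    [1, 0, 1],  # FEM + MAL (overlap)
    [0, 0, 0],  # silence
])

# union = frame-wise max over member columns (logical OR)
print(specs.derive_unions_encoding(one_hot).ravel())         # [1 1 1 0]

# intersection = frame-wise min over member columns (logical AND)
print(specs.derive_intersections_encoding(one_hot).ravel())  # [0 0 1 0]
```

Unions behave like a frame-wise logical OR over their member classes and intersections like an AND, which is exactly what `prepare_y` stacks onto the base one-hot encoding via `derive_unions_encoding` and `derive_intersections_encoding`.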