In [1]:
!pip install pyannote.audio



In [2]:
# Mount Google Drive at /content/drive; a later cell changes the working
# directory into a project folder located under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
# REPLACE THIS WITH THE PATH TO THE FINE-TUNED SEGMENTATION MODEL YOU'RE TRYING TO TEST
# The path is relative and is only consumed later by Model.from_pretrained(path_to_model),
# i.e. after the os.chdir(...) cell has run - so it is resolved against that data directory.
path_to_model = "outputs/fine_tuned_models/english_finetune_1epoch.ckpt"

In [4]:
import os

# Use the absolute location of the project data so this cell is idempotent:
# a relative chdir only works on the first run (from Colab's /content directory)
# and fails, or lands somewhere else, when the cell is executed again.
# The prefix matches the mount point passed to drive.mount('/content/drive').
DATA_DIR = '/content/drive/MyDrive/CS224S_Final_Project/data'
os.chdir(DATA_DIR)

In [5]:
# Log in to the Hugging Face Hub: paste a personal access token into the widget
# when prompted. Never write the token itself into the notebook - anything stored
# here is shared with everyone who can read the file.
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [6]:
from pyannote.database import registry, FileFinder

# Register the dataset definitions found in database.yml (looked up relative to
# the current working directory).
registry.load_database("database.yml")


def _load_protocol(protocol_name):
    """Return the named protocol, with a fresh FileFinder attached as its "audio" preprocessor."""
    preprocessors = {"audio": FileFinder()}
    return registry.get_protocol(protocol_name, preprocessors)


# One protocol per evaluation subset; loaded in the same order as before.
multilingual_data = _load_protocol("classbank.SpeakerDiarization.multilingual")
aus_data = _load_protocol("classbank.SpeakerDiarization.aus_only")
west_data = _load_protocol("classbank.SpeakerDiarization.us-aus-ned")
east_data = _load_protocol("classbank.SpeakerDiarization.jap-hk")
hk_data = _load_protocol("classbank.SpeakerDiarization.hk_only")
english_data = _load_protocol("classbank.SpeakerDiarization.english")



In [7]:
from pyannote.audio import Model
# Load the fine-tuned segmentation checkpoint selected by `path_to_model`
# (edit that variable, not this cell, to evaluate a different model).
seg_model = Model.from_pretrained(path_to_model)

In [8]:
import os

from pyannote.audio import Pipeline

# Do not hard-code the Hugging Face access token in the notebook.
# Prefer an HF_TOKEN environment variable; otherwise pass True, which asks the hub
# client to reuse the token cached by the earlier notebook_login() cell.
hf_token = os.environ.get("HF_TOKEN") or True

# Pretrained pipeline; its embedding and clustering settings are reused further
# down when the fine-tuned pipeline is assembled.
pretrained_pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    use_auth_token=hf_token,
)

# NOTE(review): from_pretrained appears to return None (after printing a hint)
# instead of raising when the gated pipeline cannot be downloaded - fail fast here
# rather than with an AttributeError in a later cell.
if pretrained_pipeline is None:
    raise RuntimeError(
        "Could not load 'pyannote/speaker-diarization-3.1'. Accept the model's user "
        "conditions on huggingface.co and authenticate via notebook_login() or HF_TOKEN."
    )

In [9]:
# Hyper-parameters handed to ft_pipeline.instantiate(...) further down.
# NOTE(review): these values look like the ones published with an earlier
# pyannote speaker-diarization release rather than 3.1 - confirm they are intended.
_segmentation_params = {
    'min_duration_off': 0.5817029604921046,
    'threshold': 0.4442333667381752,
}
_clustering_params = {
    'method': 'centroid',
    'min_cluster_size': 15,
    'threshold': 0.7153814381597874,
}
default_params = {
    'segmentation': _segmentation_params,
    'clustering': _clustering_params,
}

In [10]:
from pyannote.audio.pipelines import SpeakerDiarization
import torch

# Build a diarization pipeline that swaps in the fine-tuned segmentation model
# while reusing the embedding and clustering settings of the pretrained pipeline.
ft_pipeline = SpeakerDiarization(
    segmentation=seg_model,
    embedding=pretrained_pipeline.embedding,
    embedding_exclude_overlap=pretrained_pipeline.embedding_exclude_overlap,
    # `klustering` (with a k) is the attribute name exposed by the pipeline object;
    # it is not a typo.
    clustering=pretrained_pipeline.klustering,
)

ft_pipeline.instantiate(default_params)

# Run on the GPU when one is attached; fall back to CPU instead of crashing
# on runtimes without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ft_pipeline.to(device)


<pyannote.audio.pipelines.speaker_diarization.SpeakerDiarization at 0x78e1a94cd2d0>

In [11]:
from pyannote.metrics.diarization import DiarizationErrorRate

# One independent DER accumulator per evaluation subset, so that each subset
# gets its own per-file table and TOTAL row.
(
    metric_aus,
    metric_hk,
    metric_east,
    metric_west,
    metric_multilingual,
    metric_english,
) = (DiarizationErrorRate() for _ in range(6))

In [11]:
# AUSTRALIAN DATA
# Run the fine-tuned pipeline on every test file of the Australian subset and
# accumulate the diarization error rate (DER) in `metric_aus`.
from pyannote.audio.pipelines.utils.hook import ProgressHook
from pyannote.audio import Audio

# Start from an empty accumulator so that re-running this cell does not count
# the same files twice in the report.
metric_aus.reset()

# The loader settings are the same for every file, so build it once outside the loop.
io = Audio(mono='downmix', sample_rate=16000)

with ProgressHook() as hook:
    for counter, file in enumerate(aus_data.test(), start=1):
        waveform, sample_rate = io(file)

        file["finetuned pipeline"] = ft_pipeline({"waveform": waveform, "sample_rate": sample_rate}, hook=hook)
        # Evaluation is restricted to the file's annotated regions, passed as the UEM.
        der = metric_aus(file["annotation"], file["finetuned pipeline"], uem=file["annotated"])
        print(f"Finished running inference on example #{counter}, on filename {file['uri']} from Australia. Got a DER of {der}.")

# Per-file breakdown plus the aggregated TOTAL row.
print(metric_aus)

Output()

           diarization error rate   total correct correct false alarm false alarm missed detection missed detection confusion confusion
                                %                       %                       %                                 %                   %
item                                                                                                                                   
1004lv104                   20.16  567.40  453.01   79.84        0.00        0.00             2.23             0.39    112.16     19.77
TK09091301                  32.97 1982.07 1423.78   71.83       95.26        4.81            23.72             1.20    534.57     26.97
CC06301748                  21.96 4373.78 3476.01   79.47       62.52        1.43             2.24             0.05    895.53     20.47
1004lv203                   40.58   96.23   57.84   60.11        0.66        0.68             0.03             0.03     38.36     39.86
1004lv103                   11.71  567.40  500.9

In [12]:
# HONGKONG DATA
# Run the fine-tuned pipeline on every test file of the Hong Kong subset and
# accumulate the diarization error rate (DER) in `metric_hk`.
from pyannote.audio.pipelines.utils.hook import ProgressHook
from pyannote.audio import Audio

# Start from an empty accumulator so that re-running this cell does not count
# the same files twice in the report.
metric_hk.reset()

# The loader settings are the same for every file, so build it once outside the loop.
io = Audio(mono='downmix', sample_rate=16000)

with ProgressHook() as hook:
    for counter, file in enumerate(hk_data.test(), start=1):
        waveform, sample_rate = io(file)

        file["finetuned pipeline"] = ft_pipeline({"waveform": waveform, "sample_rate": sample_rate}, hook=hook)
        # Evaluation is restricted to the file's annotated regions, passed as the UEM.
        der = metric_hk(file["annotation"], file["finetuned pipeline"], uem=file["annotated"])
        print(f"Finished running inference on example #{counter}, on filename {file['uri']} from HongKong. Got a DER of {der}.")

# Per-file breakdown plus the aggregated TOTAL row.
print(metric_hk)

Output()

         diarization error rate  total correct correct false alarm false alarm missed detection missed detection confusion confusion
                              %                      %                       %                                 %                   %
item                                                                                                                                
3003lv02                  25.61 567.78  422.41   74.40        0.05        0.01             0.03             0.01    145.34     25.60
3004lv04                   4.42 225.48  215.52   95.58        0.00        0.00             0.03             0.01      9.93      4.40
TOTAL                     19.59 793.27  637.93   80.42        0.05        0.01             0.06             0.01    155.27     19.57


In [13]:
# EAST DATA
# Run the fine-tuned pipeline on every test file of the Hong Kong + Japan subset
# and accumulate the diarization error rate (DER) in `metric_east`.
from pyannote.audio.pipelines.utils.hook import ProgressHook
from pyannote.audio import Audio

# Start from an empty accumulator so that re-running this cell does not count
# the same files twice in the report.
metric_east.reset()

# The loader settings are the same for every file, so build it once outside the loop.
io = Audio(mono='downmix', sample_rate=16000)

with ProgressHook() as hook:
    for counter, file in enumerate(east_data.test(), start=1):
        waveform, sample_rate = io(file)

        file["finetuned pipeline"] = ft_pipeline({"waveform": waveform, "sample_rate": sample_rate}, hook=hook)
        # Evaluation is restricted to the file's annotated regions, passed as the UEM.
        der = metric_east(file["annotation"], file["finetuned pipeline"], uem=file["annotated"])
        print(f"Finished running inference on example #{counter}, on filename {file['uri']} from HK+Japan. Got a DER of {der}.")

# Per-file breakdown plus the aggregated TOTAL row.
print(metric_east)

Output()

           diarization error rate   total correct correct false alarm false alarm missed detection missed detection confusion confusion
                                %                       %                       %                                 %                   %
item                                                                                                                                   
3001lv103                   27.81  568.03  410.58   72.28        0.54        0.10             0.03             0.01    157.42     27.71
3004lv02                    14.96  567.59  482.66   85.04        0.00        0.00             0.03             0.01     84.90     14.96
TK09051822                  33.90 2862.65 2029.97   70.91      137.74        4.81            30.88             1.08    801.80     28.01
TOTAL                       30.35 3998.28 2923.21   73.11      138.28        3.46            30.94             0.77   1044.12     26.11


In [12]:
# WEST DATA
# Run the fine-tuned pipeline on every test file of the US + Netherlands + Australia
# subset and accumulate the diarization error rate (DER) in `metric_west`.
from pyannote.audio.pipelines.utils.hook import ProgressHook
from pyannote.audio import Audio

# Start from an empty accumulator so that re-running this cell does not count
# the same files twice in the report.
metric_west.reset()

# The loader settings are the same for every file, so build it once outside the loop.
io = Audio(mono='downmix', sample_rate=16000)

with ProgressHook() as hook:
    for counter, file in enumerate(west_data.test(), start=1):
        waveform, sample_rate = io(file)

        file["finetuned pipeline"] = ft_pipeline({"waveform": waveform, "sample_rate": sample_rate}, hook=hook)
        # Evaluation is restricted to the file's annotated regions, passed as the UEM.
        der = metric_west(file["annotation"], file["finetuned pipeline"], uem=file["annotated"])
        print(f"Finished running inference on example #{counter}, on filename {file['uri']} from US+Ned+Aus. Got a DER of {der}.")

# Per-file breakdown plus the aggregated TOTAL row.
print(metric_west)

Output()

           diarization error rate    total  correct correct false alarm false alarm missed detection missed detection confusion confusion
                                %                         %                       %                                 %                   %
item                                                                                                                                     
6024us403                   32.16   567.47   391.28   68.95        6.29        1.11            12.60             2.22    163.58     28.83
6013us105                   27.65   392.87   289.04   73.57        4.80        1.22             3.61             0.92    100.22     25.51
4003nl104                   25.95   568.10   425.12   74.83        4.46        0.78             2.83             0.50    140.14     24.67
6019us203                   19.87   568.07   456.62   80.38        1.43        0.25            15.02             2.64     96.43     16.97
1004lv106                   23.18 

In [13]:
# MULTILINGUAL DATA
# Run the fine-tuned pipeline on every test file of the all-languages subset and
# accumulate the diarization error rate (DER) in `metric_multilingual`.
from pyannote.audio.pipelines.utils.hook import ProgressHook
from pyannote.audio import Audio

# Start from an empty accumulator so that re-running this cell does not count
# the same files twice in the report.
metric_multilingual.reset()

# The loader settings are the same for every file, so build it once outside the loop.
io = Audio(mono='downmix', sample_rate=16000)

with ProgressHook() as hook:
    for counter, file in enumerate(multilingual_data.test(), start=1):
        waveform, sample_rate = io(file)

        file["finetuned pipeline"] = ft_pipeline({"waveform": waveform, "sample_rate": sample_rate}, hook=hook)
        # Evaluation is restricted to the file's annotated regions, passed as the UEM.
        der = metric_multilingual(file["annotation"], file["finetuned pipeline"], uem=file["annotated"])
        print(f"Finished running inference on example #{counter}, on filename {file['uri']} from all languages set. Got a DER of {der}.")

# Per-file breakdown plus the aggregated TOTAL row.
print(metric_multilingual)

Output()

           diarization error rate    total  correct correct false alarm false alarm missed detection missed detection confusion confusion
                                %                         %                       %                                 %                   %
item                                                                                                                                     
6013us101                   25.36   558.65   423.67   75.84        6.68        1.20            12.79             2.29    122.20     21.87
2010cz302                   36.73   568.01   362.86   63.88        3.49        0.61             2.49             0.44    202.65     35.68
6024us405                   41.79   433.75   258.24   59.54        5.75        1.33             7.54             1.74    167.97     38.73
CC11241853                  38.32  2358.86  1814.19   76.91      359.24       15.23            49.34             2.09    495.33     21.00
2010cz303                   31.19 

In [14]:
# ENGLISH DATA
# Run the fine-tuned pipeline on every test file of the English subset and
# accumulate the diarization error rate (DER) in `metric_english`.
from pyannote.audio.pipelines.utils.hook import ProgressHook
from pyannote.audio import Audio

# Start from an empty accumulator so that re-running this cell does not count
# the same files twice in the report.
metric_english.reset()

# The loader settings are the same for every file, so build it once outside the loop.
io = Audio(mono='downmix', sample_rate=16000)

with ProgressHook() as hook:
    for counter, file in enumerate(english_data.test(), start=1):
        waveform, sample_rate = io(file)

        file["finetuned pipeline"] = ft_pipeline({"waveform": waveform, "sample_rate": sample_rate}, hook=hook)
        # Evaluation is restricted to the file's annotated regions, passed as the UEM.
        der = metric_english(file["annotation"], file["finetuned pipeline"], uem=file["annotated"])
        print(f"Finished running inference on example #{counter}, on filename {file['uri']} from English set. Got a DER of {der}.")

# Per-file breakdown plus the aggregated TOTAL row.
print(metric_english)

Output()

           diarization error rate    total correct correct false alarm false alarm missed detection missed detection confusion confusion
                                %                        %                       %                                 %                   %
item                                                                                                                                    
6024us405                   41.79   433.75  258.24   59.54        5.75        1.33             7.54             1.74    167.97     38.73
6013us102                   26.95   567.30  418.64   73.79        4.25        0.75             4.55             0.80    144.11     25.40
1003lv02                    42.79   568.06  325.30   57.27        0.32        0.06             0.03             0.01    242.72     42.73
6013us101                   25.36   558.65  423.67   75.84        6.68        1.20            12.79             2.29    122.20     21.87
1001lv04                    24.52   557.4