In [1]:
# REPLACE THIS WITH THE PATH TO THE FINE-TUNED SEGMENTATION MODEL YOU'RE TRYING TO TEST.
# The path is relative, so it is resolved against the working directory in effect when the
# checkpoint is loaded; this notebook calls os.chdir() into the data folder before loading it.
path_to_model = "outputs/fine_tuned_models/pyannote_ausonly_finetune_1epoch.ckpt"

In [2]:
!pip install pyannote.audio



In [3]:
# Mount Google Drive at /content/drive so the project files stored there are reachable
# from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
import os

# Move into the project's data folder so the relative paths used later ("database.yml",
# the checkpoint path) resolve. The path is absolute (Drive is mounted at /content/drive)
# so this cell can be re-run safely; a relative chdir only works once, from /content.
os.chdir('/content/drive/MyDrive/CS224S_Final_Project/data')

In [3]:
# Paste your own Hugging Face access token into the login widget shown below.
# Never record the token itself in the notebook: a token was previously written in this
# comment, so it must be considered leaked and should be revoked.
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [5]:
from pyannote.database import registry, FileFinder

# Register the dataset definitions found in database.yml (looked up in the working directory).
registry.load_database("database.yml")

# Protocol names, listed in the same order as the variables they are unpacked into below.
_PROTOCOL_NAMES = (
    "classbank.SpeakerDiarization.multilingual",
    "classbank.SpeakerDiarization.aus_only",
    "classbank.SpeakerDiarization.us-aus-ned",
    "classbank.SpeakerDiarization.jap-hk",
    "classbank.SpeakerDiarization.hk_only",
)

# Each protocol gets its own FileFinder preprocessor for the "audio" key.
multilingual_data, aus_data, west_data, east_data, hk_data = (
    registry.get_protocol(protocol_name, {"audio": FileFinder()})
    for protocol_name in _PROTOCOL_NAMES
)



In [6]:
from pyannote.audio import Model
# Load the fine-tuned segmentation checkpoint selected by `path_to_model` (set in the first cell).
seg_model = Model.from_pretrained(path_to_model)

In [7]:
from pyannote.audio import Pipeline
# Authenticate with the token cached by notebook_login() instead of embedding a secret in the
# notebook. The token that used to be hard-coded on this line has been exposed and should be revoked.
pretrained_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", use_auth_token=True)

In [8]:
# Hyperparameters handed to `instantiate()` on the fine-tuned pipeline.
# NOTE(review): labelled "default params" by the author; presumably copied from the
# pretrained pipeline's configuration - confirm the source of these values.
default_params = {
    "segmentation": {
        "min_duration_off": 0.5817029604921046,
        "threshold": 0.4442333667381752,
    },
    "clustering": {
        "method": "centroid",
        "min_cluster_size": 15,
        "threshold": 0.7153814381597874,
    },
}

In [9]:
from pyannote.audio.pipelines import SpeakerDiarization
import torch

# Build a diarization pipeline around the fine-tuned segmentation model while reusing the
# embedding and clustering settings of the pretrained pipeline, so that only the
# segmentation stage differs between the two.
ft_pipeline = SpeakerDiarization(
    segmentation=seg_model,
    embedding=pretrained_pipeline.embedding,
    embedding_exclude_overlap=pretrained_pipeline.embedding_exclude_overlap,
    # `klustering` is intentional, not a typo: it is the attribute under which the pretrained
    # pipeline keeps the clustering setting it was constructed with.
    # NOTE(review): verify against the installed pyannote.audio version.
    clustering=pretrained_pipeline.klustering,
)

ft_pipeline.instantiate(default_params)

# Use the GPU when one is available, otherwise fall back to CPU instead of raising on
# runtimes without CUDA. Left as the last expression so the cell still displays the pipeline.
ft_pipeline.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))



<pyannote.audio.pipelines.speaker_diarization.SpeakerDiarization at 0x7e90666dbf10>

In [12]:
from pyannote.metrics.diarization import DiarizationErrorRate

# One independent DiarizationErrorRate instance per evaluation set, so that each
# report only contains the files of its own protocol.
(
    metric_aus,
    metric_hk,
    metric_east,
    metric_west,
    metric_multilingual,
) = (DiarizationErrorRate() for _ in range(5))

In [13]:
# AUSTRALIAN DATA
from pyannote.audio.pipelines.utils.hook import ProgressHook
from pyannote.audio import Audio

# Drop anything accumulated by a previous run of this cell; without this, re-running the
# cell would add every file to the report a second time.
metric_aus.reset()

# The loader settings are the same for every file, so build it once outside the loop:
# mono downmix, 16 kHz target sample rate.
io = Audio(mono='downmix', sample_rate=16000)

counter = 0  # stays 0 if the test subset is empty
with ProgressHook() as hook:
    for counter, file in enumerate(aus_data.test(), start=1):
        waveform, sample_rate = io(file)

        # Run the fine-tuned pipeline on the in-memory waveform and keep its output on the file.
        file["finetuned pipeline"] = ft_pipeline({"waveform": waveform, "sample_rate": sample_rate}, hook=hook)
        # Calling the metric scores this file and adds it to the metric's running report;
        # the file's "annotated" entry is passed as the UEM.
        der = metric_aus(file["annotation"], file["finetuned pipeline"], uem=file["annotated"])
        print(f"Finished running inference on example #{counter}, on filename {file['uri']} from Australia. Got a DER of {der}.")

# Per-file and aggregate report for this evaluation set.
print(metric_aus)

Output()

           diarization error rate   total correct correct false alarm false alarm missed detection missed detection confusion confusion
                                %                       %                       %                                 %                   %
item                                                                                                                                   
1004lv104                   20.16  567.40  453.01   79.84        0.00        0.00             0.11             0.02    114.29     20.14
TK09091301                  34.23 1982.07 1383.87   69.82       80.23        4.05             0.03             0.00    598.17     30.18
CC06301748                  21.67 4373.78 3478.05   79.52       52.24        1.19             0.08             0.00    895.65     20.48
1004lv203                   41.35   96.23   56.44   58.65        0.00        0.00             0.03             0.03     39.76     41.32
1004lv103                   11.71  567.40  500.9

In [14]:
# HONGKONG DATA
from pyannote.audio.pipelines.utils.hook import ProgressHook
from pyannote.audio import Audio

# Drop anything accumulated by a previous run of this cell; without this, re-running the
# cell would add every file to the report a second time.
metric_hk.reset()

# The loader settings are the same for every file, so build it once outside the loop:
# mono downmix, 16 kHz target sample rate.
io = Audio(mono='downmix', sample_rate=16000)

counter = 0  # stays 0 if the test subset is empty
with ProgressHook() as hook:
    for counter, file in enumerate(hk_data.test(), start=1):
        waveform, sample_rate = io(file)

        # Run the fine-tuned pipeline on the in-memory waveform and keep its output on the file.
        file["finetuned pipeline"] = ft_pipeline({"waveform": waveform, "sample_rate": sample_rate}, hook=hook)
        # Calling the metric scores this file and adds it to the metric's running report;
        # the file's "annotated" entry is passed as the UEM.
        der = metric_hk(file["annotation"], file["finetuned pipeline"], uem=file["annotated"])
        print(f"Finished running inference on example #{counter}, on filename {file['uri']} from HongKong. Got a DER of {der}.")

# Per-file and aggregate report for this evaluation set.
print(metric_hk)

Output()

         diarization error rate  total correct correct false alarm false alarm missed detection missed detection confusion confusion
                              %                      %                       %                                 %                   %
item                                                                                                                                
3003lv02                  25.36 567.78  425.65   74.97        1.87        0.33             0.03             0.01    142.10     25.03
3004lv04                   4.42 225.48  215.52   95.58        0.00        0.00             0.03             0.01      9.93      4.40
TOTAL                     19.41 793.27  641.17   80.83        1.87        0.24             0.06             0.01    152.03     19.17


In [15]:
# EAST DATA
from pyannote.audio.pipelines.utils.hook import ProgressHook
from pyannote.audio import Audio

# Drop anything accumulated by a previous run of this cell; without this, re-running the
# cell would add every file to the report a second time.
metric_east.reset()

# The loader settings are the same for every file, so build it once outside the loop:
# mono downmix, 16 kHz target sample rate.
io = Audio(mono='downmix', sample_rate=16000)

counter = 0  # stays 0 if the test subset is empty
with ProgressHook() as hook:
    for counter, file in enumerate(east_data.test(), start=1):
        waveform, sample_rate = io(file)

        # Run the fine-tuned pipeline on the in-memory waveform and keep its output on the file.
        file["finetuned pipeline"] = ft_pipeline({"waveform": waveform, "sample_rate": sample_rate}, hook=hook)
        # Calling the metric scores this file and adds it to the metric's running report;
        # the file's "annotated" entry is passed as the UEM.
        der = metric_east(file["annotation"], file["finetuned pipeline"], uem=file["annotated"])
        print(f"Finished running inference on example #{counter}, on filename {file['uri']} from HK+Japan. Got a DER of {der}.")

# Per-file and aggregate report for this evaluation set.
print(metric_east)

Output()

           diarization error rate   total correct correct false alarm false alarm missed detection missed detection confusion confusion
                                %                       %                       %                                 %                   %
item                                                                                                                                   
3001lv103                   27.73  568.03  411.46   72.44        0.96        0.17             0.03             0.01    156.54     27.56
3004lv02                    14.96  567.59  482.80   85.06        0.13        0.02             0.03             0.01     84.76     14.93
TK09051822                  33.08 2862.65 2060.11   71.97      144.49        5.05             0.00             0.00    802.53     28.03
TOTAL                       29.75 3998.28 2954.37   73.89      145.59        3.64             0.06             0.00   1043.84     26.11


In [12]:
from pyannote.metrics.diarization import DiarizationErrorRate
# Rebind metric_west to a fresh instance so the WEST evaluation that follows
# accumulates into an empty metric.
metric_west = DiarizationErrorRate()

In [13]:
# WEST DATA
from pyannote.audio.pipelines.utils.hook import ProgressHook
from pyannote.audio import Audio

# Drop anything accumulated by a previous run of this cell; without this, re-running the
# cell would add every file to the report a second time.
metric_west.reset()

# The loader settings are the same for every file, so build it once outside the loop:
# mono downmix, 16 kHz target sample rate.
io = Audio(mono='downmix', sample_rate=16000)

counter = 0  # stays 0 if the test subset is empty
with ProgressHook() as hook:
    for counter, file in enumerate(west_data.test(), start=1):
        waveform, sample_rate = io(file)

        # Run the fine-tuned pipeline on the in-memory waveform and keep its output on the file.
        file["finetuned pipeline"] = ft_pipeline({"waveform": waveform, "sample_rate": sample_rate}, hook=hook)
        # Calling the metric scores this file and adds it to the metric's running report;
        # the file's "annotated" entry is passed as the UEM.
        der = metric_west(file["annotation"], file["finetuned pipeline"], uem=file["annotated"])
        print(f"Finished running inference on example #{counter}, on filename {file['uri']} from US+Ned+Aus. Got a DER of {der}.")

# Per-file and aggregate report for this evaluation set.
print(metric_west)

Output()

           diarization error rate    total  correct correct false alarm false alarm missed detection missed detection confusion confusion
                                %                         %                       %                                 %                   %
item                                                                                                                                     
6024us403                   33.03   567.47   381.48   67.22        1.43        0.25             0.03             0.01    185.96     32.77
6013us105                   29.66   392.87   277.11   70.54        0.78        0.20             0.03             0.01    115.73     29.46
4003nl104                   33.34   568.10   378.70   66.66        0.00        0.00             0.03             0.01    189.37     33.33
6019us203                   19.77   568.07   455.78   80.23        0.00        0.00             0.03             0.01    112.26     19.76
1004lv106                   23.18 

In [14]:
# Rebind metric_multilingual to a fresh instance so the MULTILINGUAL evaluation that follows
# accumulates into an empty metric.
metric_multilingual = DiarizationErrorRate()

In [15]:
# MULTILINGUAL DATA
from pyannote.audio.pipelines.utils.hook import ProgressHook
from pyannote.audio import Audio

# Drop anything accumulated by a previous run of this cell; without this, re-running the
# cell would add every file to the report a second time.
metric_multilingual.reset()

# The loader settings are the same for every file, so build it once outside the loop:
# mono downmix, 16 kHz target sample rate.
io = Audio(mono='downmix', sample_rate=16000)

counter = 0  # stays 0 if the test subset is empty
with ProgressHook() as hook:
    for counter, file in enumerate(multilingual_data.test(), start=1):
        waveform, sample_rate = io(file)

        # Run the fine-tuned pipeline on the in-memory waveform and keep its output on the file.
        file["finetuned pipeline"] = ft_pipeline({"waveform": waveform, "sample_rate": sample_rate}, hook=hook)
        # Calling the metric scores this file and adds it to the metric's running report;
        # the file's "annotated" entry is passed as the UEM.
        der = metric_multilingual(file["annotation"], file["finetuned pipeline"], uem=file["annotated"])
        print(f"Finished running inference on example #{counter}, on filename {file['uri']} from all languages set. Got a DER of {der}.")

# Per-file and aggregate report for this evaluation set.
print(metric_multilingual)

Output()

           diarization error rate    total  correct correct false alarm false alarm missed detection missed detection confusion confusion
                                %                         %                       %                                 %                   %
item                                                                                                                                     
6013us101                   24.78   558.65   420.23   75.22        0.00        0.00             0.03             0.01    138.39     24.77
2010cz302                   36.46   568.01   360.90   63.54        0.00        0.00             0.03             0.01    207.07     36.46
6024us405                   42.11   433.75   251.97   58.09        0.88        0.20             0.03             0.01    181.75     41.90
CC11241853                  37.24  2358.86  1824.42   77.34      344.11       14.59             0.00             0.00    534.44     22.66
2010cz303                   31.45 

In [16]:
# Print the multilingual report again from the already-accumulated metric;
# this cell does not run any inference itself.
print(metric_multilingual)

           diarization error rate    total  correct correct false alarm false alarm missed detection missed detection confusion confusion
                                %                         %                       %                                 %                   %
item                                                                                                                                     
6013us101                   24.78   558.65   420.23   75.22        0.00        0.00             0.03             0.01    138.39     24.77
2010cz302                   36.46   568.01   360.90   63.54        0.00        0.00             0.03             0.01    207.07     36.46
6024us405                   42.11   433.75   251.97   58.09        0.88        0.20             0.03             0.01    181.75     41.90
CC11241853                  37.24  2358.86  1824.42   77.34      344.11       14.59             0.00             0.00    534.44     22.66
2010cz303                   31.45 