In [1]:
!pip install pyannote.audio



In [2]:
# Mount Google Drive inside the Colab runtime; the project data is read from it later on.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
# Checkpoint of the fine-tuned segmentation model under test -- EDIT THIS to evaluate another model.
# The path is relative: it is resolved against the working directory in effect when the
# model is loaded.
path_to_model: str = "outputs/fine_tuned_models/hk_finetune_2epoch.ckpt"

In [4]:
import os

# Switch to the project data directory so the relative paths used by this notebook
# (model checkpoint, database.yml) resolve there.
# An absolute path (Drive is mounted at /content/drive) keeps this cell idempotent:
# the previous relative chdir only worked once, from the initial working directory,
# and raised FileNotFoundError when the cell was re-run.
os.chdir('/content/drive/MyDrive/CS224S_Final_Project/data')

In [5]:
# Log in to the Hugging Face Hub interactively: paste your own access token into the widget.
# Do not write tokens into the notebook; a token that was recorded here earlier should be
# considered leaked and revoked.
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [6]:
from pyannote.database import registry, FileFinder

# Register the dataset description found in the current working directory.
registry.load_database("database.yml")


def _load_protocol(protocol_name):
    """Return the named protocol with a fresh FileFinder() attached under the "audio" key."""
    preprocessors = {"audio": FileFinder()}
    return registry.get_protocol(protocol_name, preprocessors)


# One protocol per evaluation subset; each is iterated via `.test()` further down.
multilingual_data = _load_protocol("classbank.SpeakerDiarization.multilingual")
aus_data = _load_protocol("classbank.SpeakerDiarization.aus_only")
west_data = _load_protocol("classbank.SpeakerDiarization.us-aus-ned")
east_data = _load_protocol("classbank.SpeakerDiarization.jap-hk")
hk_data = _load_protocol("classbank.SpeakerDiarization.hk_only")



In [7]:
from pyannote.audio import Model

# Load the fine-tuned segmentation model from the checkpoint selected in `path_to_model`;
# point that variable at a different checkpoint to evaluate another model.
segmentation_checkpoint = path_to_model
seg_model = Model.from_pretrained(segmentation_checkpoint)

In [8]:
from pyannote.audio import Pipeline

# Load the pretrained diarization pipeline; its embedding and clustering settings are reused
# when the fine-tuned pipeline is assembled.
# SECURITY: never hardcode a Hugging Face token in the notebook. `use_auth_token=True` makes the
# hub client use the credentials stored by `notebook_login()` earlier in this notebook.
# Any token that was previously written in this cell must be treated as leaked and revoked.
pretrained_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", use_auth_token=True)

In [9]:
# Hyper-parameters passed to `ft_pipeline.instantiate(...)`.
# NOTE(review): labelled "default params" by the author -- presumably copied from a pretrained
# pyannote pipeline configuration; confirm the source before changing them.
segmentation_params = {'min_duration_off': 0.5817029604921046, 'threshold': 0.4442333667381752}
clustering_params = {'method': 'centroid', 'min_cluster_size': 15, 'threshold': 0.7153814381597874}

default_params = {'segmentation': segmentation_params, 'clustering': clustering_params}

In [10]:
from pyannote.audio.pipelines import SpeakerDiarization
import torch

# Assemble a diarization pipeline that swaps in the fine-tuned segmentation model while
# reusing the embedding and clustering settings of the pretrained pipeline.
ft_pipeline = SpeakerDiarization(
    segmentation=seg_model,
    embedding=pretrained_pipeline.embedding,
    embedding_exclude_overlap=pretrained_pipeline.embedding_exclude_overlap,
    # NOTE: `klustering` is the actual attribute name on pyannote's pipeline, not a typo.
    clustering=pretrained_pipeline.klustering,
)

ft_pipeline.instantiate(default_params)

# Prefer the GPU, but fall back to CPU so the cell does not raise on a runtime without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ft_pipeline.to(device)


<pyannote.audio.pipelines.speaker_diarization.SpeakerDiarization at 0x7f82638c6e90>

In [11]:
from pyannote.metrics.diarization import DiarizationErrorRate

# One independent DER accumulator per evaluation subset, so the per-subset reports
# never mix results from different test sets.
(
    metric_aus,
    metric_hk,
    metric_east,
    metric_west,
    metric_multilingual,
) = (DiarizationErrorRate() for _ in range(5))

In [None]:
# AUSTRALIAN DATA
# Run the fine-tuned pipeline on every test file of the Australian-only protocol and
# accumulate the diarization error rate (DER) in `metric_aus`.
from pyannote.audio.pipelines.utils.hook import ProgressHook
from pyannote.audio import Audio

# The loader is loop-invariant, so build it once instead of once per file.
# (Named `audio_loader` rather than `io` to avoid shadowing the stdlib module name.)
audio_loader = Audio(mono='downmix', sample_rate=16000)

with ProgressHook() as hook:
    for example_num, file in enumerate(aus_data.test(), start=1):
        waveform, sample_rate = audio_loader(file)

        file["finetuned pipeline"] = ft_pipeline({"waveform": waveform, "sample_rate": sample_rate}, hook=hook)
        # Scoring is restricted to the regions passed as `uem` (the file's "annotated" entry).
        der = metric_aus(file["annotation"], file["finetuned pipeline"], uem=file["annotated"])
        print(f"Finished running inference on example #{example_num}, on filename {file['uri']} from Australia. Got a DER of {der}.")

# Report accumulated over all files processed above.
print(metric_aus)

Output()

           diarization error rate   total correct correct false alarm false alarm missed detection missed detection confusion confusion
                                %                       %                       %                                 %                   %
item                                                                                                                                   
1004lv104                   26.11  567.40  467.30   82.36       48.06        8.47             3.92             0.69     96.18     16.95
TK09091301                  53.52 1982.07 1522.65   76.82      601.37       30.34           108.92             5.50    350.50     17.68
CC06301748                  26.68 4373.78 3438.71   78.62      231.82        5.30            49.77             1.14    885.30     20.24
1004lv203                   40.60   96.23   91.29   94.87       34.14       35.48             0.71             0.73      4.23      4.39
1004lv103                   19.59  567.40  508.2

In [None]:
# HONGKONG DATA
# Run the fine-tuned pipeline on every test file of the Hong Kong-only protocol and
# accumulate the diarization error rate (DER) in `metric_hk`.
from pyannote.audio.pipelines.utils.hook import ProgressHook
from pyannote.audio import Audio

# The loader is loop-invariant, so build it once instead of once per file.
# (Named `audio_loader` rather than `io` to avoid shadowing the stdlib module name.)
audio_loader = Audio(mono='downmix', sample_rate=16000)

with ProgressHook() as hook:
    for example_num, file in enumerate(hk_data.test(), start=1):
        waveform, sample_rate = audio_loader(file)

        file["finetuned pipeline"] = ft_pipeline({"waveform": waveform, "sample_rate": sample_rate}, hook=hook)
        # Scoring is restricted to the regions passed as `uem` (the file's "annotated" entry).
        der = metric_hk(file["annotation"], file["finetuned pipeline"], uem=file["annotated"])
        print(f"Finished running inference on example #{example_num}, on filename {file['uri']} from HongKong. Got a DER of {der}.")

# Report accumulated over all files processed above.
print(metric_hk)

Output()

         diarization error rate  total correct correct false alarm false alarm missed detection missed detection confusion confusion
                              %                      %                       %                                 %                   %
item                                                                                                                                
3003lv02                  25.75 567.78  422.02   74.33        0.44        0.08             0.03             0.01    145.73     25.67
3004lv04                   4.42 225.48  215.52   95.58        0.00        0.00             0.03             0.01      9.93      4.40
TOTAL                     19.69 793.27  637.54   80.37        0.44        0.06             0.06             0.01    155.66     19.62


In [None]:
# EAST DATA
# Run the fine-tuned pipeline on every test file of the Japan + Hong Kong protocol and
# accumulate the diarization error rate (DER) in `metric_east`.
from pyannote.audio.pipelines.utils.hook import ProgressHook
from pyannote.audio import Audio

# The loader is loop-invariant, so build it once instead of once per file.
# (Named `audio_loader` rather than `io` to avoid shadowing the stdlib module name.)
audio_loader = Audio(mono='downmix', sample_rate=16000)

with ProgressHook() as hook:
    for example_num, file in enumerate(east_data.test(), start=1):
        waveform, sample_rate = audio_loader(file)

        file["finetuned pipeline"] = ft_pipeline({"waveform": waveform, "sample_rate": sample_rate}, hook=hook)
        # Scoring is restricted to the regions passed as `uem` (the file's "annotated" entry).
        der = metric_east(file["annotation"], file["finetuned pipeline"], uem=file["annotated"])
        print(f"Finished running inference on example #{example_num}, on filename {file['uri']} from HK+Japan. Got a DER of {der}.")

# Report accumulated over all files processed above.
print(metric_east)

Output()

           diarization error rate   total correct correct false alarm false alarm missed detection missed detection confusion confusion
                                %                       %                       %                                 %                   %
item                                                                                                                                   
3001lv103                   28.55  568.03  406.90   71.63        1.05        0.18             0.03             0.01    161.11     28.36
3004lv02                    14.98  567.59  482.59   85.02        0.00        0.00             0.03             0.01     84.97     14.97
TK09051822                  39.26 2862.65 2065.22   72.14      326.52       11.41            22.63             0.79    774.80     27.07
TOTAL                       34.29 3998.28 2954.71   73.90      327.57        8.19            22.69             0.57   1020.87     25.53


In [12]:
# WEST DATA
# Run the fine-tuned pipeline on every test file of the US + Netherlands + Australia protocol
# and accumulate the diarization error rate (DER) in `metric_west`.
from pyannote.audio.pipelines.utils.hook import ProgressHook
from pyannote.audio import Audio

# The loader is loop-invariant, so build it once instead of once per file.
# (Named `audio_loader` rather than `io` to avoid shadowing the stdlib module name.)
audio_loader = Audio(mono='downmix', sample_rate=16000)

with ProgressHook() as hook:
    for example_num, file in enumerate(west_data.test(), start=1):
        waveform, sample_rate = audio_loader(file)

        file["finetuned pipeline"] = ft_pipeline({"waveform": waveform, "sample_rate": sample_rate}, hook=hook)
        # Scoring is restricted to the regions passed as `uem` (the file's "annotated" entry).
        der = metric_west(file["annotation"], file["finetuned pipeline"], uem=file["annotated"])
        print(f"Finished running inference on example #{example_num}, on filename {file['uri']} from US+Ned+Aus. Got a DER of {der}.")

# Report accumulated over all files processed above.
print(metric_west)

Output()

           diarization error rate    total  correct correct false alarm false alarm missed detection missed detection confusion confusion
                                %                         %                       %                                 %                   %
item                                                                                                                                     
6024us403                   33.88   567.47   379.12   66.81        3.92        0.69            11.59             2.04    176.76     31.15
6013us105                   41.68   392.87   311.46   79.28       82.32       20.95             8.77             2.23     72.64     18.49
4003nl104                   34.52   568.10   414.04   72.88       42.05        7.40             6.78             1.19    147.27     25.92
6019us203                   22.24   568.07   454.63   80.03       12.91        2.27             8.50             1.50    104.94     18.47
1004lv106                   34.07 

In [13]:
# MULTILINGUAL DATA
# Run the fine-tuned pipeline on every test file of the multilingual protocol and
# accumulate the diarization error rate (DER) in `metric_multilingual`.
from pyannote.audio.pipelines.utils.hook import ProgressHook
from pyannote.audio import Audio

# The loader is loop-invariant, so build it once instead of once per file.
# (Named `audio_loader` rather than `io` to avoid shadowing the stdlib module name.)
audio_loader = Audio(mono='downmix', sample_rate=16000)

with ProgressHook() as hook:
    for example_num, file in enumerate(multilingual_data.test(), start=1):
        waveform, sample_rate = audio_loader(file)

        file["finetuned pipeline"] = ft_pipeline({"waveform": waveform, "sample_rate": sample_rate}, hook=hook)
        # Scoring is restricted to the regions passed as `uem` (the file's "annotated" entry).
        der = metric_multilingual(file["annotation"], file["finetuned pipeline"], uem=file["annotated"])
        print(f"Finished running inference on example #{example_num}, on filename {file['uri']} from all languages set. Got a DER of {der}.")

# Report accumulated over all files processed above.
print(metric_multilingual)

Output()

           diarization error rate    total  correct correct false alarm false alarm missed detection missed detection confusion confusion
                                %                         %                       %                                 %                   %
item                                                                                                                                     
6013us101                   30.81   558.65   432.61   77.44       46.07        8.25             8.77             1.57    117.27     20.99
2010cz302                   37.30   568.01   362.38   63.80        6.24        1.10             0.62             0.11    205.00     36.09
6024us405                   41.53   433.75   255.47   58.90        1.84        0.42             2.29             0.53    175.99     40.57
CC11241853                  44.73  2358.86  1796.38   76.15      492.75       20.89            35.90             1.52    526.58     22.32
2010cz303                   32.26 

In [14]:
# Show the accumulated multilingual DER report again without re-running inference.
multilingual_report = str(metric_multilingual)
print(multilingual_report)

           diarization error rate    total  correct correct false alarm false alarm missed detection missed detection confusion confusion
                                %                         %                       %                                 %                   %
item                                                                                                                                     
6013us101                   30.81   558.65   432.61   77.44       46.07        8.25             8.77             1.57    117.27     20.99
2010cz302                   37.30   568.01   362.38   63.80        6.24        1.10             0.62             0.11    205.00     36.09
6024us405                   41.53   433.75   255.47   58.90        1.84        0.42             2.29             0.53    175.99     40.57
CC11241853                  44.73  2358.86  1796.38   76.15      492.75       20.89            35.90             1.52    526.58     22.32
2010cz303                   32.26 