In [1]:
import logging
logging.disable(logging.CRITICAL)
from omegaconf import OmegaConf

from pyannote.metrics.diarization import DiarizationErrorRate

from nemo.utils import logging

In [2]:
from util.nemo_util import *
from model.NeMo_diarizer import *

## Initialize Model

We first set configuration

In [3]:
device_id = 1
MODEL_CONFIG = os.path.join('config','model_config.yaml')
config = OmegaConf.load(MODEL_CONFIG)
config.device = f'cuda:{device_id}'
config.verbose = True
model = NeMoDiarizer(cfg=config)
torch.set_default_device(config.device)

In [5]:
audio_dir = 'conversations'
filenames = []
for file in os.listdir(audio_dir):
    if file[-4:] == '.wav':
        filenames.append(os.path.join(audio_dir, file))

# get duration
durations = []

# get number of speakers
num_speakers = []

for filename in filenames:
    json_path = f'{filename[:-4]}.json'
    raw_json = json.loads(open(json_path, 'r').read())
    num_speaker = len(raw_json['participants'])
    # for p in raw_json['participants']:
    #     if p['name'] in ('Hearth', 'participant', 'Participant'):
    #         continue
    num_speakers.append(num_speaker)
    durations.append(raw_json['duration'])

In [6]:
model.diarize(filenames, durations)

splitting manifest: 100%|██████████| 21/21 [00:10<00:00,  1.97it/s]
vad: 100%|██████████| 1760/1760 [06:12<00:00,  4.73it/s]
                                                                 

KeyboardInterrupt: 

In [None]:
diarize_performance = {}
for file in filenames:
    # Merge discontinued labels
    merged_label = Annotation()

    filename = file.split('/')[-1]

    with open(os.path.join('outputs', 'pred_json', f'{filename[:-4]}_labels.txt')) as f:
        for line in f.readlines():
            start, end, speaker = line.split()
            start, end = float(start), float(end)
            merged_label[Segment(start, end)] = speaker

    # Evaluate metrics using merged label
    true_labels = rttm_to_labels(os.path.join(audio_dir, f'{filename[:-4]}.rttm'))
    reference = labels_to_pyannote_object(true_labels)

    performance = DiarizationErrorRate().compute_components(reference, merged_label)
    metrics = ['confusion', 'missed detection', 'false alarm']
    for metric in metrics:
        performance[metric] /= performance['total']
    performance['DER'] = sum(performance[metric] for metric in metrics)
    diarize_performance[filename] = performance



In [None]:
import pandas as pd

df = {x : [] for x in performance}
index = []

for filename in diarize_performance:
    for x in diarize_performance[filename]:
        df[x].append(diarize_performance[filename][x])
    index.append(filename)

df = pd.DataFrame(df, index=index)

df.to_csv('diarize_performance_no_oracle.csv')

In [None]:
df[df["DER"] <= 0.15]

Unnamed: 0,confusion,total,correct,false alarm,missed detection,DER
conversation-81.wav,0.020155,4999.4,4681.312,0.023827,0.04347,0.087452
conversation-70.wav,0.050032,2015.3,1825.4,0.021431,0.044197,0.11566
conversation-40.wav,0.045632,4483.7,4181.48,0.017209,0.021772,0.084613
conversation-67.wav,0.054663,5250.7,4700.945,0.027641,0.050038,0.132342
conversation-42.wav,0.023578,5551.2,5240.178,0.025278,0.03245,0.081306
conversation-80.wav,0.019234,3192.0,3000.495,0.013515,0.040761,0.07351
conversation-41.wav,0.042411,3882.2,3586.005,0.030042,0.033885,0.106338
conversation-7.wav,0.051628,2585.4,2321.58,0.011832,0.050414,0.113874


In [None]:
df

Unnamed: 0,confusion,total,correct,false alarm,missed detection,DER
conversation-61.wav,0.044377,2820.2,2508.27,0.050809,0.066229,0.161414
conversation-69.wav,0.085201,3829.6,3081.896,0.023342,0.110042,0.218585
conversation-81.wav,0.020155,4999.4,4681.312,0.023827,0.04347,0.087452
conversation-55.wav,0.10091,5908.2,4720.978,0.014153,0.100035,0.215098
conversation-70.wav,0.050032,2015.3,1825.4,0.021431,0.044197,0.11566
conversation-65.wav,0.128648,5459.8,4593.031,0.029136,0.030107,0.187891
conversation-59.wav,0.0727,4036.5,3519.085,0.046459,0.055484,0.174643
conversation-40.wav,0.045632,4483.7,4181.48,0.017209,0.021772,0.084613
conversation-79.wav,0.104786,2686.9,2168.88,0.021623,0.088008,0.214418
conversation-66.wav,0.439641,5044.9,2714.496,0.037293,0.022291,0.499226


In [None]:
df.describe()

Unnamed: 0,confusion,total,correct,false alarm,missed detection,DER
count,21.0,21.0,21.0,21.0,21.0,21.0
mean,0.128494,3955.793333,3257.507619,0.027606,0.065733,0.221833
std,0.12155,1443.440785,1331.835534,0.016045,0.041605,0.153253
min,0.019234,244.8,113.915,0.009092,0.021772,0.07351
25%,0.045632,2866.3,2321.58,0.015528,0.033885,0.113874
50%,0.085201,4036.5,3105.945,0.023458,0.050038,0.174643
75%,0.130677,5044.9,4593.031,0.030042,0.088008,0.218585
max,0.439641,5908.2,5240.178,0.066113,0.169649,0.546385
