In [7]:
import os
import numpy as np
from collections import defaultdict
from pyannote.core import Annotation, Segment
from pyannote.database.util import load_rttm
from pyannote.metrics.diarization import DiarizationErrorRate, JaccardErrorRate

# load rttm files for metrics calculating
base_dir = '/home/jovyan/work'
reference_path = '{}/datasets/voxconverse/test/rttm'.format(base_dir)
hypothesis_path = '{}/voxcon-test-rttm/rttm-colar-auto'.format(base_dir)

sample_ids = ["aepyx", "aiqwk", "bjruf", "bmsyn", "bxcfq", "byapz", "clfcg", "cqfmj", "crylr", "cvofp", "dgvwu", "dohag", "dxbbt", "dzsef", "eauve", "eazeq", "eguui", "epygx", "eqsta", "euqef", "fijfi", "fpfvy", "fqrnu", "fxnwf", "fyqoe", "gcfwp", "gtjow", "gtnjb", "gukoa", "guvqf", "gylzn", "gyomp", "hcyak", "heolf", "hhepf", "ibrnm", "ifwki", "iiprr", "ikhje", "jdrwl", "jjkrt", "jjvkx", "jrfaz", "jsbdo", "jttar", "jxpom", "jzkzt", "kajfh", "kmunk", "kpjud", "ktvto", "kvkje", "lbfnx", "ledhe", "lilfy", "ljpes", "lkikz", "lpola", "lscfc", "ltgmz", "lubpm", "luobn", "mjmgr", "msbyq", "mupzb", "myjoe", "nlvdr", "nprxc", "ocfop", "ofbxh", "olzkb", "ooxlj", "oqwpd", "otmpf", "ouvtt", "poucc", "ppexo", "pwnsw", "qadia", "qeejz", "qlrry", "qwepo", "rarij", "rmvsh", "rxulz", "sebyw", "sexgc", "sfdvy", "svxzm", "tkybe", "tpslg", "uedkc", "uqxlg", "usqam", "vncid", "vylyk", "vzuru", "wdvva", "wemos", "wprog", "wwzsk", "xggbk", "xkgos", "xlyov", "xmyyy", "xqxkt", "xtdcl", "xtzoq", "xvxwv", "ybhwz", "ylzez", "ytmef", "yukhy", "yzvon", "zedtj", "zfzlc", "zowse", "zqidv", "zztbo", "ralnu", "uicid", "laoyl", "jxydp", "pzxit", "upshw", "gfneh", "kzmyi", "nkqzr", "kgjaa", "dkabn", "eucfa", "erslt", "mclsr", "fzwtp", "dzxut", "pkwrt", "gmmwm", "leneg", "sxqvt", "pgtkk", "fuzfh", "vtzqw", "rsypp", "qxana", "optsn", "dxokr", "ptses", "isxwc", "gzhwb", "mhwyr", "duvox", "ezxso", "jgiyq", "rpkso", "kmjvh", "wcxfk", "gcvrb", "eddje", "pccww", "vuewy", "tvtoe", "oubab", "jwggf", "aggyz", "bidnq", "neiye", "mkhie", "iowob", "jbowg", "gwloo", "uevxo", "nitgx", "eoyaz", "qoarn", "mxdpo", "auzru", "diysk", "cwbvu", "jeymh", "iacod", "cawnd", "vgaez", "bgvvt", "tiido", "aorju", "qajyo", "ryken", "iabca", "tkhgs", "tbjqx", "mqtep", "fowhl", "fvhrk", "nqcpi", "mbzht", "uhfrw", "utial", "cpebh", "tnjoh", "jsymf", "vgevv", "mxduo", "gkiki", "bvyvm", "hqhrb", "isrps", "nqyqm", "dlast", "pxqme", "bpzsc", "vdlvr", "lhuly", "crorm", "bvqnu", "tpnyf", "thnuq", "swbnm", "cadba", "sbrmv", "wibky", "wlfsf", "wwvcs", "xffsa", "xkmqx", "xlsme", "ygrip", "ylgug", "ytula", "zehzu", "zsgto", "zzsba", "zzyyo"]

derMetric = DiarizationErrorRate(collar=0.25)
jerMetric = JaccardErrorRate(collar=0.25)

derMetricSO = DiarizationErrorRate(collar=0.25, skip_overlap=True)
jerMetricSO = JaccardErrorRate(collar=0.25, skip_overlap=True)

for idx, sample_id in enumerate(sample_ids):
    reference_file = os.path.join(reference_path, sample_id + '.rttm')
    hypotesis_file = os.path.join(hypothesis_path, sample_id + '.rttm')

    reference = load_rttm(reference_file)[sample_id]
    hypothesis = load_rttm(hypotesis_file)[sample_id]
    
    der = derMetric(reference, hypothesis, detailed = True)
    jer = jerMetric(reference, hypothesis)
    
    confusion = der['confusion']/der['total']
    false_alarm = der['false alarm']/der['total']
    missed_detection = der['missed detection']/der['total']
    
    derSO = derMetricSO(reference, hypothesis, detailed = True)
    jerSO = jerMetricSO(reference, hypothesis)
    
    confusionSO = derSO['confusion']/derSO['total']
    false_alarmSO = derSO['false alarm']/derSO['total']
    missed_detectionSO = derSO['missed detection']/derSO['total']
    print(idx,'/',len(sample_ids)-1)

print('--------------- rttm-colar-auto -----------------')
print(f'DER = {100 * der["diarization error rate"]:.1f}% JER = {100 * jer:.1f}%')
print(f'DER_conf = {100 * confusion:.1f}% DER_fa = {100 * false_alarm:.1f}% DER_md = {100 * missed_detection:.1f}%')
print('\nSkip Overlap')
print(f'DER = {100 * derSO["diarization error rate"]:.1f}% JER = {100 * jerSO:.1f}%')
print(f'DER_conf = {100 * confusionSO:.1f}% DER_fa = {100 * false_alarmSO:.1f}% DER_md = {100 * missed_detectionSO:.1f}%')

0 / 231
1 / 231
2 / 231
3 / 231
4 / 231
5 / 231
6 / 231
7 / 231
8 / 231
9 / 231
10 / 231
11 / 231
12 / 231
13 / 231
14 / 231
15 / 231
16 / 231
17 / 231
18 / 231
19 / 231
20 / 231
21 / 231
22 / 231
23 / 231
24 / 231
25 / 231
26 / 231
27 / 231
28 / 231
29 / 231
30 / 231
31 / 231
32 / 231
33 / 231
34 / 231
35 / 231
36 / 231
37 / 231
38 / 231
39 / 231
40 / 231
41 / 231
42 / 231
43 / 231
44 / 231
45 / 231
46 / 231
47 / 231
48 / 231
49 / 231
50 / 231
51 / 231
52 / 231
53 / 231
54 / 231
55 / 231
56 / 231
57 / 231
58 / 231
59 / 231
60 / 231
61 / 231
62 / 231
63 / 231
64 / 231
65 / 231
66 / 231
67 / 231
68 / 231
69 / 231
70 / 231
71 / 231
72 / 231
73 / 231
74 / 231
75 / 231
76 / 231
77 / 231
78 / 231
79 / 231
80 / 231
81 / 231
82 / 231
83 / 231
84 / 231
85 / 231
86 / 231
87 / 231
88 / 231
89 / 231
90 / 231
91 / 231
92 / 231
93 / 231
94 / 231
95 / 231
96 / 231
97 / 231
98 / 231
99 / 231
100 / 231
101 / 231
102 / 231
103 / 231
104 / 231
105 / 231
106 / 231
107 / 231
108 / 231
109 / 231
110 / 231


VoxConverse Test Set

------------------ rttm --------------------
DER = 57.7% JER = 90.4%
DER_conf = 53.1% DER_fa = 0.8% DER_md = 3.7%

Skip Overlap
DER = 56.3% JER = 90.3%
DER_conf = 54.5% DER_fa = 0.9% DER_md = 1.0%

--------------- rttm-collar ----------------
DER = 58.0% JER = 90.3%
DER_conf = 53.1% DER_fa = 1.7% DER_md = 3.2%

Skip Overlap
DER = 56.7% JER = 90.2%
DER_conf = 54.5% DER_fa = 1.8% DER_md = 0.4%

---------------- rttm-auto -----------------
DER = 58.0% JER = 92.6%
DER_conf = 53.5% DER_fa = 0.8% DER_md = 3.7%

Skip Overlap
DER = 56.2% JER = 92.4%
DER_conf = 54.4% DER_fa = 0.9% DER_md = 1.0%

------------- rttm-colar-auto --------------
DER = 58.2% JER = 92.6%
DER_conf = 53.6% DER_fa = 1.5% DER_md = 3.1%

Skip Overlap
DER = 56.5% JER = 92.4%
DER_conf = 54.5% DER_fa = 1.6% DER_md = 0.4%
