In [1]:
import os
import sys

import matplotlib.pyplot as plt

# Location of the project checkout. The default is the original hard-coded
# path; set the SPEAKER_DIARIZATION_ROOT environment variable to run the
# notebook from a different machine or directory without editing this cell.
root = '/home/incomplete/ai/'
project_root = os.environ.get(
    'SPEAKER_DIARIZATION_ROOT',
    os.path.join(root, 'speaker-diarization'),
)

# Later cells use paths relative to the project root (e.g. 'result/...').
os.chdir(project_root)
# Make the project-local modules importable. Guarded so that re-running this
# cell does not keep appending duplicate entries to sys.path.
if project_root not in sys.path:
    sys.path.append(project_root)

plt.style.use('dark_background')

In [2]:
import json
from glob import glob
import os
import numpy as np
import copy
import itertools

import result.rttm
import rttm_tools
from typing import List, Tuple, Dict, Optional

In [3]:
# Characters treated as punctuation (an ASCII space plus CJK punctuation marks).
# The labeling loop below drops any word whose text is found in this string.
PUNCTUATIONS = ' 。，？！、'

# reviewed
def find_labels(start_seconds, stop_seconds, rttms):
    """
    Collect the overlaps between one time span and a list of annotated segments.

    Each entry of `rttms` must expose `.start`, `.stop` and `.speaker`.
    For every entry, `result.rttm.intersection` is asked to compare the span
    [start_seconds, stop_seconds] with [entry.start, entry.stop]; of the parts
    it returns, only those tagged 'intersection' are kept.

    Returns a list of `(duration, (start, stop), speaker)` tuples ordered by
    duration, longest first, or None when nothing overlaps.
    """
    overlaps = []
    for entry in rttms:
        pieces = result.rttm.intersection(
            start_seconds, stop_seconds, 'l', entry.start, entry.stop, 'r'
        )
        # NOTE(review): None presumably means "no overlap at all" -- confirm
        # against result.rttm.intersection.
        if pieces is None:
            continue
        overlaps.extend(
            (seg_stop - seg_start, (seg_start, seg_stop), entry.speaker)
            for seg_start, seg_stop, kind in pieces
            if kind == 'intersection'
        )
    # Longest overlap first; callers test the result for truthiness, so an
    # empty result is reported as None rather than [].
    ranked = sorted(overlaps, key=lambda item: item[0], reverse=True)
    return ranked or None

# Threshold, in milliseconds, passed to rttm_tools.merge_intervals when the
# per-speaker word intervals are merged into sentences.
# NOTE(review): presumably the largest gap that still gets bridged -- confirm
# against rttm_tools.merge_intervals.
MERGE_THRESHOLD_MS = 250

# reviewed
def pick_words_of_speaker(words_with_labels, speaker) -> List[Tuple[float, float]]:
    """
    Select, for every word attributed to `speaker`, its longest overlap span.

    Args:
        words_with_labels: iterable of `(word, labels)` pairs, where `labels`
            is a list of `(duration, (start_s, stop_s), label)` tuples with
            times in seconds (the shape produced by `find_labels`).
        speaker: the label to keep (compared with `==`).

    Returns:
        One `(start_ms, stop_ms)` tuple per word that has at least one label
        equal to `speaker`, in input order. Note the unit: the input spans are
        in seconds, the returned spans are in MILLISECONDS.
        Words without a matching label contribute nothing.
    """
    res = []
    for _word, labels in words_with_labels:
        candidates = [
            (duration, ends)
            for duration, ends, label in labels
            if label == speaker
        ]
        if not candidates:
            continue
        # Longest duration wins; ties are broken by the larger (start, stop),
        # which is the same ordering a descending tuple sort would give.
        _duration, (start_s, stop_s) = max(candidates)
        res.append((start_s * 1000, stop_s * 1000))
    return res



# obtain label from human annotation
#
# For every transcript this cell:
#   1. tags each word with the speakers whose annotated segments overlap it,
#   2. merges each speaker's words into sentences and writes them as RTTM,
#   3. derives a VAD .lab file from those same sentences,
#   4. dumps every word as a TSV label track.

# Merged speaker segments shorter than this (in milliseconds) are discarded.
# Experimental ("FOR TRY ONLY"); the RTTM and the VAD output share this one
# filter so that they stay in sync.
MIN_SEGMENT_MS = 1500

for path in glob('result/vad_xf/*.transcript.json'):
    print('processing', path)
    # JSON is UTF-8 by specification and word text is compared against CJK
    # punctuation, so do not depend on the platform default encoding.
    with open(path, encoding='utf-8') as f:
        sentences = json.load(f)

    # load human labels
    basename = os.path.basename(path)
    file_id = basename.split('.')[0]
    rttm_path = os.path.join('result/label_rttm_raw', f'{file_id}.rttm')
    rttms = result.rttm.load_rttm(rttm_path)

    word_id = 0
    words_dropped = set()
    # tag each word with speaker,
    # words with no speaker information are dropped
    words_with_labels = []
    for sent in sentences:
        for word in sent['words']:
            # NOTE: this mutates the loaded transcript by adding an 'id' key.
            word['id'] = word_id
            word_id += 1
            # punctuations are ignored
            if word['text'] in PUNCTUATIONS:
                words_dropped.add(word['id'])
                continue
            w_start = word['start_ms_audio_time']
            w_stop = word['stop_ms_audio_time']
            # if the word is labeled, then keep it,
            # otherwise drop it
            labels = find_labels(w_start/1000, w_stop/1000, rttms)
            if labels:
                words_with_labels.append((word, labels))
            else:
                # words with no labels are dropped
                words_dropped.add(word['id'])

    # sanity check: every word is either kept or dropped, exactly once
    old_len = sum(len(x['words']) for x in sentences)
    new_len = len(words_with_labels) + len(words_dropped)
    assert old_len == new_len, (
        f'{file_id}: {old_len} words in transcript but {new_len} accounted for'
    )

    all_labels = {
        label
        for _word, labels in words_with_labels
        for _duration, _ends, label in labels
    }
    # sanity check: exactly the two expected speakers appear
    assert all_labels == {'s', 't'}, (
        f'{file_id}: unexpected speaker labels {sorted(all_labels)}'
    )

    # An earlier variant used the full word span for every speaker present on
    # a word; pick_words_of_speaker uses the longest overlap span instead.
    t_words = pick_words_of_speaker(words_with_labels, 't')
    s_words = pick_words_of_speaker(words_with_labels, 's')
    # sanity check
    # pigeon hole: every labeled word belongs to at least one of the two
    # speakers, and to at most both of them
    assert len(t_words) + len(s_words) >= len(words_with_labels)
    assert len(t_words) + len(s_words) <= 2 * old_len
    t_sents = rttm_tools.merge_intervals(
        t_words, [], MERGE_THRESHOLD_MS
    )
    s_sents = rttm_tools.merge_intervals(
        s_words, [], MERGE_THRESHOLD_MS
    )

    # remove short segs (FOR TRY ONLY)
    # Filter once, so that the RTTM rows and the VAD below are built from
    # exactly the same set of segments.
    kept_sents = {
        't': [(start, stop) for start, stop in t_sents
              if stop - start >= MIN_SEGMENT_MS],
        's': [(start, stop) for start, stop in s_sents
              if stop - start >= MIN_SEGMENT_MS],
    }

    # Rttm takes (file id, start, duration, speaker) with times in seconds.
    rows = [
        result.rttm.Rttm(file_id, start/1000, (stop-start)/1000, tag)
        for tag in ('t', 's')
        for start, stop in kept_sents[tag]
    ]
    rttm_out_path = os.path.join('result/label_rttm_xf', f'{file_id}.rttm')
    result.rttm.write_rttm(rows, rttm_out_path)
    # write label for audacity
    result.rttm.write_tsv(
        [
            (r.start, r.stop, r.speaker)
            for r in rows
        ],
        f'{rttm_out_path}.txt'
    )

    # make vad
    # Alternative (currently disabled): build the VAD from every labeled word
    # instead of from the filtered speaker segments.
    # vad_words = [
    #     (w['start_ms_audio_time'], w['stop_ms_audio_time'])
    #     for w, _ in words_with_labels
    # ]
    # sents = rttm_tools.merge_intervals(vad_words, [], MERGE_THRESHOLD_MS)
    # sents = [
    #     (start/1000, stop/1000, 'v')
    #     for start, stop in sents
    # ]

    # make vad
    # this is insane (FOR TRY ONLY)
    # translate the rttm directly to vad
    # here the length filtering is done before merging,
    # to keep it in-sync with the rttm
    sents = rttm_tools.merge_intervals(
        kept_sents['t'] + kept_sents['s'], [], 0
    )
    sents = [
        (start/1000, stop/1000)
        for start, stop in sents
    ]

    vad_out_path = os.path.join('result/vad_xf', f'{file_id}.lab')
    result.rttm.write_lab(sents, vad_out_path)
    # write label for audacity
    result.rttm.write_tsv(
        [(x, y, 'v') for x, y in sents],
        vad_out_path + '.txt'
    )

    # label each word (all words, including the dropped ones), in seconds
    words = [
        (
            word['start_ms_audio_time']/1000,
            word['stop_ms_audio_time']/1000,
            word['text']
        )
        for sent in sentences
        for word in sent['words']
    ]
    result.rttm.write_tsv(words, os.path.join('result/vad_xf', f'{file_id}.word.txt'))

processing result/vad_xf/F_clean_hdgs.wav.transcript.json
processing result/vad_xf/M_clean_yztj.wav.transcript.json
processing result/vad_xf/M_echo_sjzm.wav.transcript.json


In [None]:
# Report, per .lab file, how much of the VAD output consists of short segments.

# Segments shorter than this are counted as "short".
# NOTE(review): the .lab files are written in seconds by the previous cell, so
# this is presumably seconds as well -- confirm against result.rttm.load_lab.
SHORT_THRESHOLD = 0.75

for path in glob('result/vad_xf/*.lab'):
    lines = result.rttm.load_lab(path)

    basename = os.path.basename(path)
    file_id = basename.split('.')[0]

    # each line is indexed as (start, stop, ...)
    lengths = [
        line[1] - line[0]
        for line in lines
    ]
    short_threshold = SHORT_THRESHOLD
    print(file_id, 'short_threshold', short_threshold)
    if not lengths:
        # An empty .lab file would otherwise raise ZeroDivisionError below.
        print('no segments, skipping')
        continue
    shorts = [x for x in lengths if x < short_threshold]
    print('counts', len(shorts), len(lengths), len(shorts)/len(lengths))
    print('durations', np.sum(shorts), np.sum(lengths), np.sum(shorts)/np.sum(lengths))
    print('mean duration', np.mean(lengths))

    fig, ax = plt.subplots()
    ax.hist(shorts, bins=100)
    ax.set(
        title=f'{file_id}: segments shorter than {short_threshold}',
        xlabel='segment length',
        ylabel='count',
    )
    plt.show()
    # One figure per file: release it so figures do not pile up in memory.
    plt.close(fig)