In [None]:
import logging
logging.disable(logging.CRITICAL)
from omegaconf import OmegaConf

from pyannote.metrics.diarization import DiarizationErrorRate

from nemo.utils import logging

In [None]:
from util.nemo_util import *
# from model.NeMo_diarizer import NeMoDiarizer

## Model Class

In [None]:
from util.nemo_util import *

class NeMoDiarizer(ClusteringDiarizer):
    def __init__(self, cfg: DictConfig, speaker_model=None):
        super().__init__(cfg, speaker_model)
        self.output_dir = cfg.diarizer.out_dir
        self.rttm_dir = os.path.join(self.output_dir, 'pred_rttms')
        self.json_dir = os.path.join(self.output_dir, 'pred_json')

        if not os.path.exists(self.output_dir):
            os.mkdir(self.output_dir)
        if not os.path.exists(self.json_dir):
            os.mkdir(self.json_dir)

        # Oracle speaker num
        self._cluster_params.oracle_num_speakers = True

        torch.set_default_device(self._cfg.device)

    def diarize(self, paths2audio_files: List[str], durations: List, num_speakers: List[int] = None, output_label: bool = True, batch_size: int = 0):
        """
        Diarize list of audio files in paths2audio_files
        Arguments:
            * paths2audio_files: list of audio files to be diarized
            * duration: duration of each audio file
            * num_speakers: if not None, then it corresponds to number of speakers in each of the audio input files
            * output_label: if True, saves text label file that can be used in Audacity application for visualization
            * batch_size: batch_sizeconsidered for extraction of speaker embedding and VAD computation
        """
        # setup manifest file
        config_setup(paths2audio_files, self._diarizer_params.manifest_filepath, durations, num_speakers=num_speakers)

        self._cluster_params.oracle_num_speakers = num_speakers is not None

        self._out_dir = self._diarizer_params.out_dir

        self._speaker_dir = os.path.join(self._diarizer_params.out_dir, 'speaker_outputs')

        if os.path.exists(self._speaker_dir):
            logging.warning("Deleting previous clustering diarizer outputs.")
            shutil.rmtree(self._speaker_dir, ignore_errors=True)
        os.makedirs(self._speaker_dir)

        if not os.path.exists(self._out_dir):
            os.mkdir(self._out_dir)

        self._vad_dir = os.path.join(self._out_dir, 'vad_outputs')
        self._vad_out_file = os.path.join(self._vad_dir, "vad_out.json")

        if batch_size:
            self._cfg.batch_size = batch_size

        if paths2audio_files:
            if type(paths2audio_files) is list:
                self._diarizer_params.manifest_filepath = os.path.join(self._out_dir, 'paths2audio_filepath.json')
                config_setup(paths2audio_files, self._diarizer_params.manifest_filepath, durations, num_speakers=num_speakers)
                # self.path2audio_files_to_manifest(paths2audio_files, self._diarizer_params.manifest_filepath)
            else:
                raise ValueError("paths2audio_files must be of type list of paths to file containing audio file")

        self.AUDIO_RTTM_MAP = audio_rttm_map(self._diarizer_params.manifest_filepath)

        out_rttm_dir = os.path.join(self._out_dir, 'pred_rttms')
        os.makedirs(out_rttm_dir, exist_ok=True)

        # Speech Activity Detection
        self._perform_speech_activity_detection()

        # Segmentation
        scales = self.multiscale_args_dict['scale_dict'].items()
        for scale_idx, (window, shift) in scales:

            # Segmentation for the current scale (scale_idx)
            self._run_segmentation(window, shift, scale_tag=f'_scale{scale_idx}')

            # Embedding Extraction for the current scale (scale_idx)
            self._extract_embeddings(self.subsegments_manifest_path, scale_idx, len(scales))

            self.multiscale_embeddings_and_timestamps[scale_idx] = [self.embeddings, self.time_stamps]

        embs_and_timestamps = get_embs_and_timestamps(
            self.multiscale_embeddings_and_timestamps, self.multiscale_args_dict
        )

        # Clustering
        all_reference, all_hypothesis = perform_clustering(
            embs_and_timestamps=embs_and_timestamps,
            AUDIO_RTTM_MAP=self.AUDIO_RTTM_MAP,
            out_rttm_dir=out_rttm_dir,
            clustering_params=self._cluster_params,
            device=self._speaker_model.device,
            verbose=self.verbose,
        )
        logging.info("Outputs are saved in {} directory".format(os.path.abspath(self._diarizer_params.out_dir)))


        # generate label file
        if output_label:
            for uniq_id, audio_rttm_values in self.AUDIO_RTTM_MAP.items():
            
                filename = audio_rttm_values.get('audio_filepath').split('/')[-1][:-4]
                labels = rttm_to_labels(os.path.join(self.rttm_dir, f'{filename}.rttm'))
                hypothesis = labels_to_pyannote_object(labels)

                last_label = {
                    'start' : None, 'end' : None, 'label' : None
                }

                with open(os.path.join(self.json_dir, f'{filename}_labels.txt'), 'w') as f:
                    for segment, track, label in hypothesis.itertracks(yield_label=True):
                        start, end = segment.start, segment.end
                        if label == last_label['label']:
                            last_label['end'] = end
                            continue
                        # write previous label
                        if last_label['label'] is not None:
                            f.write(f"{last_label['start']}\t{last_label['end']}\t{last_label['label']}\n")
                        last_label = {
                            'start' : start, 'end' : end, 'label' : label
                        }
                    f.write(f"{last_label['start']}\t{last_label['end']}\t{last_label['label']}\n")

    def _run_vad(self, manifest_file):
        """
        Run voice activity detection. 
        Get log probability of voice activity detection and smoothes using the post processing parameters. 
        Using generated frame level predictions generated manifest file for later speaker embedding extraction.
        input:
        manifest_file (str) : Manifest file containing path to audio file and label as infer

        """

        shutil.rmtree(self._vad_dir, ignore_errors=True)
        os.makedirs(self._vad_dir)

        self._vad_model.eval()

        time_unit = int(self._vad_window_length_in_sec / self._vad_shift_length_in_sec)
        trunc = int(time_unit / 2)
        trunc_l = time_unit - trunc
        all_len = 0
        data = []
        for line in open(manifest_file, 'r', encoding='utf-8'):
            file = json.loads(line)['audio_filepath']
            data.append(get_uniqname_from_filepath(file))

        status = get_vad_stream_status(data)
        for i, test_batch in enumerate(
            tqdm(self._vad_model.test_dataloader(), desc='vad', leave=True, disable=not self.verbose)
        ):
            test_batch = [x.to(self._vad_model.device) for x in test_batch]
            with autocast():
                log_probs = self._vad_model(input_signal=test_batch[0], input_signal_length=test_batch[1])
                probs = torch.softmax(log_probs, dim=-1)
                pred = probs[:, 1]
                if status[i] == 'start':
                    to_save = pred[:-trunc]
                elif status[i] == 'next':
                    to_save = pred[trunc:-trunc_l]
                elif status[i] == 'end':
                    to_save = pred[trunc_l:]
                else:
                    to_save = pred
                all_len += len(to_save)
                outpath = os.path.join(self._vad_dir, data[i] + ".frame")
                with open(outpath, "a", encoding='utf-8') as fout:
                    for f in range(len(to_save)):
                        fout.write('{0:0.4f}\n'.format(to_save[f]))
            del test_batch
            if status[i] == 'end' or status[i] == 'single':
                all_len = 0

        if not self._vad_params.smoothing:
            # Shift the window by 10ms to generate the frame and use the prediction of the window to represent the label for the frame;
            self.vad_pred_dir = self._vad_dir
            frame_length_in_sec = self._vad_shift_length_in_sec
        else:
            # Generate predictions with overlapping input segments. Then a smoothing filter is applied to decide the label for a frame spanned by multiple segments.
            # smoothing_method would be either in majority vote (median) or average (mean)
            logging.info("Generating predictions with overlapping input segments")
            smoothing_pred_dir = generate_overlap_vad_seq(
                frame_pred_dir=self._vad_dir,
                smoothing_method=self._vad_params.smoothing,
                overlap=self._vad_params.overlap,
                window_length_in_sec=self._vad_window_length_in_sec,
                shift_length_in_sec=self._vad_shift_length_in_sec,
                num_workers=self._cfg.num_workers,
            )
            self.vad_pred_dir = smoothing_pred_dir
            frame_length_in_sec = 0.01

        logging.info("Converting frame level prediction to speech/no-speech segment in start and end times format.")

        vad_params = self._vad_params if isinstance(self._vad_params, (DictConfig, dict)) else self._vad_params.dict()
        table_out_dir = generate_vad_segment_table(
            vad_pred_dir=self.vad_pred_dir,
            postprocessing_params=vad_params,
            frame_length_in_sec=frame_length_in_sec,
            num_workers=self._cfg.num_workers,
            out_dir=self._vad_dir,
        )

        AUDIO_VAD_RTTM_MAP = {}
        for key in self.AUDIO_RTTM_MAP:
            if os.path.exists(os.path.join(table_out_dir, key + ".txt")):
                AUDIO_VAD_RTTM_MAP[key] = deepcopy(self.AUDIO_RTTM_MAP[key])
                AUDIO_VAD_RTTM_MAP[key]['rttm_filepath'] = os.path.join(table_out_dir, key + ".txt")
            else:
                logging.warning(f"no vad file found for {key} due to zero or negative duration")

        write_rttm2manifest(AUDIO_VAD_RTTM_MAP, self._vad_out_file)
        self._speaker_manifest_path = self._vad_out_file

    def _extract_embeddings(self, manifest_file: str, scale_idx: int, num_scales: int):
        """
        This method extracts speaker embeddings from segments passed through manifest_file
        Optionally you may save the intermediate speaker embeddings for debugging or any use. 
        """
        logging.info("Extracting embeddings for Diarization")
        self._setup_spkr_test_data(manifest_file)
        self.embeddings = {}
        self._speaker_model.eval()
        self.time_stamps = {}

        all_embs = torch.empty([0]).cpu()
        for test_batch in tqdm(
            self._speaker_model.test_dataloader(),
            desc=f'[{scale_idx+1}/{num_scales}] extract embeddings',
            leave=True,
            disable=not self.verbose,
        ):
            test_batch = [x.to(self._speaker_model.device) for x in test_batch]
            audio_signal, audio_signal_len, labels, slices = test_batch
            with autocast():
                _, embs = self._speaker_model.forward(input_signal=audio_signal, input_signal_length=audio_signal_len)
                emb_shape = embs.shape[-1]
                embs = embs.view(-1, emb_shape)
                all_embs = torch.cat((all_embs, embs.cpu().detach()), dim=0)
            del test_batch

        with open(manifest_file, 'r', encoding='utf-8') as manifest:
            for i, line in enumerate(manifest.readlines()):
                line = line.strip()
                dic = json.loads(line)
                uniq_name = get_uniqname_from_filepath(dic['audio_filepath'])
                if uniq_name in self.embeddings:
                    self.embeddings[uniq_name] = torch.cat((self.embeddings[uniq_name], all_embs[i].view(1, -1)))
                else:
                    self.embeddings[uniq_name] = all_embs[i].view(1, -1)
                if uniq_name not in self.time_stamps:
                    self.time_stamps[uniq_name] = []
                start = dic['offset']
                end = start + dic['duration']
                self.time_stamps[uniq_name].append([start, end])

        if self._speaker_params.save_embeddings:
            embedding_dir = os.path.join(self._speaker_dir, 'embeddings')
            if not os.path.exists(embedding_dir):
                os.makedirs(embedding_dir, exist_ok=True)

            prefix = get_uniqname_from_filepath(manifest_file)
            name = os.path.join(embedding_dir, prefix)
            self._embeddings_file = name + f'_embeddings.pkl'
            pkl.dump(self.embeddings, open(self._embeddings_file, 'wb'))
            logging.info("Saved embedding files to {}".format(embedding_dir))


## Initialize Model

We first set configuration

In [3]:
device_id = 1
MODEL_CONFIG = os.path.join('config','model_config.yaml')
config = OmegaConf.load(MODEL_CONFIG)
config.device = f'cuda:{device_id}'
config.verbose = True
model = NeMoDiarizer(cfg=config)
torch.set_default_device(config.device)

In [5]:
audio_dir = 'conversations'
filenames = []
for file in os.listdir(audio_dir):
    if file[-4:] == '.wav':
        filenames.append(os.path.join(audio_dir, file))

# get duration
durations = []

# get number of speakers
num_speakers = []

for filename in filenames:
    json_path = f'{filename[:-4]}.json'
    raw_json = json.loads(open(json_path, 'r').read())
    num_speaker = len(raw_json['participants'])
    # for p in raw_json['participants']:
    #     if p['name'] in ('Hearth', 'participant', 'Participant'):
    #         continue
    num_speakers.append(num_speaker)
    durations.append(raw_json['duration'])

In [6]:
model.diarize(filenames, durations)

splitting manifest: 100%|██████████| 21/21 [00:10<00:00,  1.97it/s]
vad: 100%|██████████| 1760/1760 [06:12<00:00,  4.73it/s]
                                                                 

KeyboardInterrupt: 

In [None]:
diarize_performance = {}
for file in filenames:
    # Merge discontinued labels
    merged_label = Annotation()

    filename = file.split('/')[-1]

    with open(os.path.join('outputs', 'pred_json', f'{filename[:-4]}_labels.txt')) as f:
        for line in f.readlines():
            start, end, speaker = line.split()
            start, end = float(start), float(end)
            merged_label[Segment(start, end)] = speaker

    # Evaluate metrics using merged label
    true_labels = rttm_to_labels(os.path.join(audio_dir, f'{filename[:-4]}.rttm'))
    reference = labels_to_pyannote_object(true_labels)

    performance = DiarizationErrorRate().compute_components(reference, merged_label)
    metrics = ['confusion', 'missed detection', 'false alarm']
    for metric in metrics:
        performance[metric] /= performance['total']
    performance['DER'] = sum(performance[metric] for metric in metrics)
    diarize_performance[filename] = performance



In [None]:
import pandas as pd

df = {x : [] for x in performance}
index = []

for filename in diarize_performance:
    for x in diarize_performance[filename]:
        df[x].append(diarize_performance[filename][x])
    index.append(filename)

df = pd.DataFrame(df, index=index)

df.to_csv('diarize_performance_no_oracle.csv')

In [None]:
df[df["DER"] <= 0.15]

Unnamed: 0,confusion,total,correct,false alarm,missed detection,DER
conversation-81.wav,0.020155,4999.4,4681.312,0.023827,0.04347,0.087452
conversation-70.wav,0.050032,2015.3,1825.4,0.021431,0.044197,0.11566
conversation-40.wav,0.045632,4483.7,4181.48,0.017209,0.021772,0.084613
conversation-67.wav,0.054663,5250.7,4700.945,0.027641,0.050038,0.132342
conversation-42.wav,0.023578,5551.2,5240.178,0.025278,0.03245,0.081306
conversation-80.wav,0.019234,3192.0,3000.495,0.013515,0.040761,0.07351
conversation-41.wav,0.042411,3882.2,3586.005,0.030042,0.033885,0.106338
conversation-7.wav,0.051628,2585.4,2321.58,0.011832,0.050414,0.113874


In [3]:
import pandas as pd

df = pd.read_csv("diarize_performance.csv")

In [6]:
df[df["confusion"] > 0.1]

Unnamed: 0.1,Unnamed: 0,total,missed detection,false alarm,correct,confusion,DER
3,conversation-55.wav,5908.2,0.104163,0.016196,4691.321,0.101801,0.22216
5,conversation-65.wav,5459.8,0.03044,0.029127,4586.261,0.129554,0.189121
8,conversation-79.wav,2686.9,0.088853,0.021623,2164.915,0.105417,0.215894
9,conversation-66.wav,5044.9,0.031653,0.027466,4351.357,0.105821,0.16494
10,conversation-73.wav,2866.3,0.138907,0.062837,1595.22,0.304549,0.506294
12,conversation-46.wav,4386.7,0.044642,0.03016,2704.77,0.338773,0.413576
15,conversation-39.wav,3350.16,3e-06,0.102231,1828.23,0.454283,0.556517
16,conversation-82.wav,4582.1,0.023462,0.015482,3869.37,0.132084,0.171029
18,conversation-57.wav,5895.6,0.025938,0.014189,4341.423,0.237679,0.277805


In [None]:
df.describe()

Unnamed: 0,confusion,total,correct,false alarm,missed detection,DER
count,21.0,21.0,21.0,21.0,21.0,21.0
mean,0.128494,3955.793333,3257.507619,0.027606,0.065733,0.221833
std,0.12155,1443.440785,1331.835534,0.016045,0.041605,0.153253
min,0.019234,244.8,113.915,0.009092,0.021772,0.07351
25%,0.045632,2866.3,2321.58,0.015528,0.033885,0.113874
50%,0.085201,4036.5,3105.945,0.023458,0.050038,0.174643
75%,0.130677,5044.9,4593.031,0.030042,0.088008,0.218585
max,0.439641,5908.2,5240.178,0.066113,0.169649,0.546385
