### Installing Dependencies

In [None]:
!pip install git+https://github.com/m-bain/whisperX.git@a5dca2cc65b1a37f32a347e574b2c56af3a7434a
!pip install --no-build-isolation nemo_toolkit[asr]==1.21.0
!pip install git+https://github.com/facebookresearch/demucs#egg=demucs
!pip install deepmultilingualpunctuation
!pip install wget pydub
!pip install numba==0.58.0
!pip install unidecode

Collecting git+https://github.com/m-bain/whisperX.git@a5dca2cc65b1a37f32a347e574b2c56af3a7434a
  Cloning https://github.com/m-bain/whisperX.git (to revision a5dca2cc65b1a37f32a347e574b2c56af3a7434a) to /tmp/pip-req-build-mp1g___f
  Running command git clone --filter=blob:none --quiet https://github.com/m-bain/whisperX.git /tmp/pip-req-build-mp1g___f
  Running command git rev-parse -q --verify 'sha^a5dca2cc65b1a37f32a347e574b2c56af3a7434a'
  Running command git fetch -q https://github.com/m-bain/whisperX.git a5dca2cc65b1a37f32a347e574b2c56af3a7434a
  Running command git checkout -q a5dca2cc65b1a37f32a347e574b2c56af3a7434a
  Resolved https://github.com/m-bain/whisperX.git to commit a5dca2cc65b1a37f32a347e574b2c56af3a7434a
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting faster-whisper@ git+https://github.com/SYSTRAN/faster-whisper.git@0.10.0 (from whisperx==3.1.1)
  Cloning https://github.com/SYSTRAN/faster-whisper.git (to revision 0.10.0) to /tmp/pip-install-jnl5k27e/fast

Collecting nemo_toolkit[asr]==1.21.0
  Downloading nemo_toolkit-1.21.0-py3-none-any.whl (2.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m9.8 MB/s[0m eta [36m0:00:00[0m
Collecting onnx>=1.7.0 (from nemo_toolkit[asr]==1.21.0)
  Downloading onnx-1.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (15.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m15.7/15.7 MB[0m [31m46.2 MB/s[0m eta [36m0:00:00[0m
Collecting wget (from nemo_toolkit[asr]==1.21.0)
  Downloading wget-3.2.zip (10 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting braceexpand (from nemo_toolkit[asr]==1.21.0)
  Downloading braceexpand-0.1.7-py2.py3-none-any.whl (5.9 kB)
Collecting g2p-en (from nemo_toolkit[asr]==1.21.0)
  Downloading g2p_en-2.1.0-py3-none-any.whl (3.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m80.4 MB/s[0m eta [36m0:00:00[0m
Collecting jiwer (from nemo_toolkit

### Import libraries

In [None]:
import os
import wget
from omegaconf import OmegaConf
import json
import shutil
from faster_whisper import WhisperModel
import whisperx
import torch
from pydub import AudioSegment
from nemo.collections.asr.models.msdd_models import NeuralDiarizer
from deepmultilingualpunctuation import PunctuationModel
import re
import logging
import nltk
from whisperx.alignment import DEFAULT_ALIGN_MODELS_HF, DEFAULT_ALIGN_MODELS_TORCH
from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE
import unidecode
from unidecode import unidecode
import pathlib
from pathlib import Path

  warn(
[NeMo W 2024-01-24 20:08:27 transformer_bpe_models:59] Could not import NeMo NLP collection which is required for speech translation model.


### Helper Functions

In [None]:
# Two-letter codes of the languages treated as compatible with the punctuation model
punct_model_langs = "en fr de es it nl pt bg pl cs sk sl".split()
# Languages that have a wav2vec2 alignment model: keys of the TORCH table first,
# followed by the keys of the HF table.
wav2vec2_langs = [*DEFAULT_ALIGN_MODELS_TORCH.keys(), *DEFAULT_ALIGN_MODELS_HF.keys()]

# Sorted language codes followed by the sorted, title-cased language names.
whisper_langs = [
    *sorted(LANGUAGES.keys()),
    *sorted(name.title() for name in TO_LANGUAGE_CODE.keys()),
]


def create_config(output_dir, DOMAIN_TYPE="telephonic"):
    """Build the NeMo diarization inference config for ``<output_dir>/mono_file.wav``.

    The ``diar_infer_<DOMAIN_TYPE>.yaml`` template is downloaded from the NeMo
    GitHub repository into ``output_dir`` unless a file with that name is
    already there. A one-line manifest describing the audio file is written to
    ``<output_dir>/data/input_manifest.json`` and the template is then
    overridden with the VAD, speaker-embedding and MSDD settings used by this
    notebook.

    Args:
        output_dir: Working directory for the template, the manifest and the
            diarizer outputs. It is created when it does not exist yet.
        DOMAIN_TYPE: Which template to fetch: "meeting", "telephonic" or
            "general", depending on the kind of audio.

    Returns:
        The loaded and overridden OmegaConf configuration.
    """
    CONFIG_FILE_NAME = f"diar_infer_{DOMAIN_TYPE}.yaml"
    CONFIG_URL = f"https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_tasks/diarization/conf/inference/{CONFIG_FILE_NAME}"

    # Create the working tree first: the download below targets `output_dir`
    # as a directory, so it has to exist before wget is called.
    data_dir = os.path.join(output_dir, "data")
    os.makedirs(data_dir, exist_ok=True)

    MODEL_CONFIG = os.path.join(output_dir, CONFIG_FILE_NAME)
    if not os.path.exists(MODEL_CONFIG):
        MODEL_CONFIG = wget.download(CONFIG_URL, output_dir)

    config = OmegaConf.load(MODEL_CONFIG)

    # Single-entry manifest: one JSON object on one line.
    manifest_path = os.path.join(data_dir, "input_manifest.json")
    meta = {
        "audio_filepath": os.path.join(output_dir, "mono_file.wav"),
        "offset": 0,
        "duration": None,
        "label": "infer",
        "text": "-",
        "rttm_filepath": None,
        "uem_filepath": None,
    }
    with open(manifest_path, "w") as fp:
        json.dump(meta, fp)
        fp.write("\n")

    pretrained_vad = "vad_multilingual_marblenet"  # VAD model
    pretrained_speaker_model = "titanet_large"
    config.num_workers = 0  # Workaround for multiprocessing hanging with ipython issue
    config.diarizer.manifest_filepath = manifest_path
    # Directory to store intermediate files and prediction outputs
    config.diarizer.out_dir = output_dir

    config.diarizer.speaker_embeddings.model_path = pretrained_speaker_model
    # compute VAD provided with model_path to vad config
    config.diarizer.oracle_vad = False
    config.diarizer.clustering.parameters.oracle_num_speakers = False

    # Here, we use our in-house pretrained NeMo VAD model
    config.diarizer.vad.model_path = pretrained_vad
    config.diarizer.vad.parameters.onset = 0.8
    config.diarizer.vad.parameters.offset = 0.6
    config.diarizer.vad.parameters.pad_offset = -0.05
    # The telephonic MSDD checkpoint is used whatever DOMAIN_TYPE selects;
    # DOMAIN_TYPE only decides which YAML template is downloaded.
    config.diarizer.msdd_model.model_path = "diar_msdd_telephonic"

    return config

# Pick the instant that represents a word, given its start and end time
def get_word_ts_anchor(s, e, option="start"):
    """Return the reference timestamp of a word spanning ``s`` to ``e``.

    ``option`` selects the end ("end") or the midpoint ("mid"); any other
    value, including the default "start", yields the start.
    """
    if option == "mid":
        return (s + e) / 2
    return e if option == "end" else s

# Attach a speaker label to every word by walking words and speaker turns together
def get_words_speaker_mapping(wrd_ts, spk_ts, word_anchor_option="start"):
    """Assign each word to the speaker turn that contains its anchor time.

    Args:
        wrd_ts: Iterable of dicts with "start", "end" and "word". Start and end
            are multiplied by 1000 and truncated to int before comparison.
        spk_ts: Sequence of ``(start, end, speaker)`` turns; the turn end is
            compared directly against the scaled word times.
        word_anchor_option: Passed to ``get_word_ts_anchor`` to choose which
            instant of the word ("start", "mid" or "end") is compared.

    Returns:
        A list of ``{"word", "start_time", "end_time", "speaker"}`` dicts, one
        per input word, in input order.
    """
    _, turn_end, speaker = spk_ts[0]
    final_turn = len(spk_ts) - 1
    turn_idx = 0
    mapping = []

    for entry in wrd_ts:
        start_ms = int(entry["start"] * 1000)
        end_ms = int(entry["end"] * 1000)
        text = entry["word"]
        anchor = get_word_ts_anchor(start_ms, end_ms, word_anchor_option)

        # Advance through the turns until one ends at or after the anchor.
        while anchor > float(turn_end):
            turn_idx = min(turn_idx + 1, final_turn)
            _, turn_end, speaker = spk_ts[turn_idx]
            if turn_idx == final_turn:
                # No turn is left: stretch the last one to the end of this word
                # so that the word still gets a speaker.
                turn_end = get_word_ts_anchor(start_ms, end_ms, option="end")

        mapping.append(
            {"word": text, "start_time": start_ms, "end_time": end_ms, "speaker": speaker}
        )

    return mapping

# Characters that close a sentence when they are the last character of a word
sentence_ending_punctuations = ".?!"

# Walk left from a word to find the first word of its sentence
def get_first_word_idx_of_sentence(word_idx, word_list, speaker_list, max_words):
    """Return the index of the first word of the sentence containing ``word_idx``.

    The search moves left while the previous word has the same speaker, does
    not end a sentence, and fewer than ``max_words`` steps have been taken.

    Returns:
        The index found, provided it is 0 or directly follows a word that ends
        a sentence; otherwise -1.
    """

    def ends_sentence(idx):
        return idx >= 0 and word_list[idx][-1] in sentence_ending_punctuations

    left_idx = word_idx
    while left_idx > 0:
        if word_idx - left_idx >= max_words:
            break
        if speaker_list[left_idx - 1] != speaker_list[left_idx]:
            break
        if ends_sentence(left_idx - 1):
            break
        left_idx -= 1

    if left_idx == 0 or ends_sentence(left_idx - 1):
        return left_idx
    return -1

# Walk right from a word to find the last word of its sentence
def get_last_word_idx_of_sentence(word_idx, word_list, max_words):
    """Return the index of the last word of the sentence containing ``word_idx``.

    The search moves right until a word ends with one of the module-level
    ``sentence_ending_punctuations``, ``max_words`` steps have been taken, or
    the final word of ``word_list`` is reached.

    Returns:
        The index found, provided it is the final word of the list or a word
        that ends a sentence; otherwise -1.
    """

    def ends_sentence(idx):
        return idx >= 0 and word_list[idx][-1] in sentence_ending_punctuations

    final_idx = len(word_list) - 1
    right_idx = word_idx
    # Never step past the final word: the sentence-end check below indexes
    # word_list[right_idx], which would raise IndexError at len(word_list).
    while (
        right_idx < final_idx
        and right_idx - word_idx < max_words
        and not ends_sentence(right_idx)
    ):
        right_idx += 1

    if right_idx == final_idx or ends_sentence(right_idx):
        return right_idx
    return -1

# Realign speaker labels so that a speaker change does not fall in the middle of a sentence
def get_realigned_ws_mapping_with_punctuation(
    word_speaker_mapping, max_words_in_sentence=50
):
    """Relabel words so that each sentence is attributed to a single speaker.

    For every word where the speaker label changes although the word does not
    end a sentence, the enclosing sentence is located with
    ``get_first_word_idx_of_sentence`` / ``get_last_word_idx_of_sentence``
    (limited to ``max_words_in_sentence`` words). All words of that sentence
    then receive the speaker that occurs most often in it, unless that speaker
    covers fewer than half (integer division) of the sentence's words.

    Args:
        word_speaker_mapping: List of dicts with at least "word" and "speaker".
        max_words_in_sentence: Upper bound on the sentence length considered.

    Returns:
        A new list of shallow copies of the input dicts with "speaker" updated.
        The input dicts themselves are left unchanged.
    """
    total = len(word_speaker_mapping)

    words_list, speaker_list = [], []
    for item in word_speaker_mapping:
        words_list.append(item["word"])
        speaker_list.append(item["speaker"])

    def changes_speaker_mid_sentence(idx):
        return (
            idx < total - 1
            and speaker_list[idx] != speaker_list[idx + 1]
            and words_list[idx][-1] not in sentence_ending_punctuations
        )

    idx = 0
    while idx < total:
        if changes_speaker_mid_sentence(idx):
            left_idx = get_first_word_idx_of_sentence(
                idx, words_list, speaker_list, max_words_in_sentence
            )
            if left_idx > -1:
                # Whatever was not used up on the left side is the budget for the right side.
                right_idx = get_last_word_idx_of_sentence(
                    idx, words_list, max_words_in_sentence - idx + left_idx - 1
                )
            else:
                right_idx = -1

            if min(left_idx, right_idx) != -1:
                sentence = slice(left_idx, right_idx + 1)
                spk_labels = speaker_list[sentence]
                mod_speaker = max(set(spk_labels), key=spk_labels.count)
                if spk_labels.count(mod_speaker) >= len(spk_labels) // 2:
                    speaker_list[sentence] = [mod_speaker] * (right_idx - left_idx + 1)
                    # Resume right after the sentence that was just relabelled.
                    idx = right_idx
        idx += 1

    realigned_list = []
    for item, speaker in zip(word_speaker_mapping, speaker_list):
        updated = item.copy()
        updated["speaker"] = speaker
        realigned_list.append(updated)

    return realigned_list

# Group the per-word speaker labels into per-speaker sentences
def get_sentences_speaker_mapping(word_speaker_mapping, spk_ts):
    """Merge consecutive words into sentences that each carry one speaker.

    A new sentence is started whenever the speaker changes or the Punkt
    tokenizer reports a sentence break in the current text plus the next word.

    Args:
        word_speaker_mapping: Dicts with "word", "speaker", "start_time" and
            "end_time".
        spk_ts: Speaker turns as ``(start, end, speaker)``; only the first turn
            is read, to initialise the first sentence.

    Returns:
        A list of ``{"speaker", "start_time", "end_time", "text"}`` dicts where
        "speaker" is formatted as ``"Speaker <label>"``.
    """
    contains_break = nltk.tokenize.PunktSentenceTokenizer().text_contains_sentbreak

    def open_sentence(label, start, end):
        return {
            "speaker": f"Speaker {label}",
            "start_time": start,
            "end_time": end,
            "text": "",
        }

    first_start, first_end, previous_speaker = spk_ts[0]
    current = open_sentence(previous_speaker, first_start, first_end)
    sentences = []

    for entry in word_speaker_mapping:
        word, speaker = entry["word"], entry["speaker"]
        start, end = entry["start_time"], entry["end_time"]

        continues_sentence = speaker == previous_speaker and not contains_break(
            current["text"] + " " + word
        )
        if continues_sentence:
            current["end_time"] = end
        else:
            sentences.append(current)
            current = open_sentence(speaker, start, end)

        current["text"] += word + " "
        previous_speaker = speaker

    sentences.append(current)
    return sentences


def get_speaker_aware_transcript(sentences_speaker_mapping, f):
    """Write the sentences to ``f`` as one paragraph per speaker turn.

    Each paragraph starts with ``"<speaker>: "``; paragraphs are separated by a
    blank line. Every sentence text is written followed by a single space.

    Args:
        sentences_speaker_mapping: Non-empty list of dicts with "speaker" and
            "text".
        f: Writable text file object.
    """
    current_speaker = sentences_speaker_mapping[0]["speaker"]
    f.write(f"{current_speaker}: ")

    for entry in sentences_speaker_mapping:
        speaker, text = entry["speaker"], entry["text"]

        # A different speaker opens a new paragraph.
        if speaker != current_speaker:
            current_speaker = speaker
            f.write(f"\n\n{current_speaker}: ")

        f.write(text + " ")


def format_timestamp(
    milliseconds: float, always_include_hours: bool = False, decimal_marker: str = "."
):
    """Format a duration in milliseconds as ``[HH:]MM:SS<marker>mmm``.

    Args:
        milliseconds: Non-negative duration. Floats are rounded to the nearest
            whole millisecond.
        always_include_hours: Emit the ``HH:`` part even when it is zero.
        decimal_marker: Separator between seconds and milliseconds.

    Returns:
        The formatted timestamp string.
    """
    assert milliseconds >= 0, "non-negative timestamp expected"

    # The integer format codes used below reject floats, so normalise to a
    # whole number of milliseconds before splitting it up.
    remainder = int(round(milliseconds))
    hours, remainder = divmod(remainder, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    seconds, remainder = divmod(remainder, 1_000)

    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{remainder:03d}"

# Write the transcript to a file as numbered SRT cues
def write_srt(transcript, file):
    """Print every segment of ``transcript`` to ``file`` as one SRT cue.

    A cue consists of its 1-based number, a ``start --> end`` line with
    comma-separated milliseconds and hours always shown, and the line
    ``"<speaker>: <text>"``. The text is stripped and any ``-->`` inside it is
    replaced by ``->``. The file is flushed after each cue.

    Args:
        transcript: Iterable of dicts with "start_time", "end_time", "speaker"
            and "text".
        file: Writable text file object.
    """
    for cue_number, segment in enumerate(transcript, start=1):
        cue_start = format_timestamp(
            segment["start_time"], always_include_hours=True, decimal_marker=","
        )
        cue_end = format_timestamp(
            segment["end_time"], always_include_hours=True, decimal_marker=","
        )
        speaker = segment["speaker"]
        cue_text = segment["text"].strip().replace("-->", "->")
        cue = f"{cue_number}\n{cue_start} --> {cue_end}\n{speaker}: {cue_text}\n"
        print(cue, file=file, flush=True)


# Collect the ids of all vocabulary tokens that contain a digit or one of % $ £
def find_numeral_symbol_tokens(tokenizer):
    """Return the ids of tokens containing a numeral or currency/percent symbol.

    Args:
        tokenizer: Object whose ``get_vocab()`` returns a token -> id mapping.

    Returns:
        A list that starts with -1 followed by the matching token ids in
        vocabulary order. (-1 is presumably the marker for the transcriber's
        default suppression set; confirm against the transcriber's
        documentation.)
    """
    numeral_symbols = "0123456789%$£"
    matching_ids = [
        token_id
        for token, token_id in tokenizer.get_vocab().items()
        if any(char in numeral_symbols for char in token)
    ]
    return [-1, *matching_ids]


# Start time of the next word that has one; untimed words on the way are merged
def _get_next_start_timestamp(word_timestamps, current_word_index):
    """Return the start time of the first later word that has a "start".

    Every word between the current one and that word lacks a start timestamp;
    each of them is appended (space separated) to the current word's text and
    has its own "word" set to None so that the caller can drop it.

    When the current word is the last one, or every later word had to be
    merged into it, the current word's own start is returned.
    """
    current = word_timestamps[current_word_index]
    for follower in word_timestamps[current_word_index + 1 :]:
        if follower.get("start") is not None:
            return follower["start"]
        # The follower has no start timestamp: fold it into the current word.
        current["word"] += " " + follower["word"]
        follower["word"] = None

    # Nothing timed follows, so the current word is effectively the last one.
    return current["start"]


# Pre-processing: repair words that came back without timestamps
def filter_missing_timestamps(word_timestamps):
    """Fill in missing word timestamps and drop words that were merged away.

    A word without "start" gets the previous word's "end" as its start (0 for
    the very first word) and the next available start as its end; consecutive
    untimed words are merged into the first of them. The dicts of
    ``word_timestamps`` are modified in place.

    Args:
        word_timestamps: List of dicts with "word" and, when known, "start"
            and "end".

    Returns:
        A new list holding the surviving word dicts in order; an empty list
        for empty input.
    """
    if not word_timestamps:
        return []

    # handle the first word, which has no predecessor to borrow a start from
    first = word_timestamps[0]
    if first.get("start") is None:
        first["start"] = 0
        first["end"] = _get_next_start_timestamp(word_timestamps, 0)

    result = [first]

    for i, ws in enumerate(word_timestamps[1:], start=1):
        # if ws doesn't have a start and end
        # use the previous end as start and next start as end
        if ws.get("start") is None and ws.get("word") is not None:
            ws["start"] = word_timestamps[i - 1]["end"]
            ws["end"] = _get_next_start_timestamp(word_timestamps, i)

        # words merged into a predecessor carry word=None and are skipped
        if ws["word"] is not None:
            result.append(ws)
    return result


def cleanup(path: str):
    """Delete the file, symlink or directory tree at ``path``.

    ``path`` may be relative or absolute. Files and symlinks are removed with
    ``os.remove``; directories are removed together with their content.

    Raises:
        ValueError: If ``path`` is neither a file, a symlink nor a directory.
    """
    is_file_or_link = os.path.isfile(path) or os.path.islink(path)
    if is_file_or_link:
        os.remove(path)
        return

    if os.path.isdir(path):
        shutil.rmtree(path)
        return

    raise ValueError("Path {} is not a file or dir.".format(path))

# Validate the language argument and convert language names to language codes.
def process_language_arg(language: str, model_name: str):
    """Normalise ``language`` to a language code that fits ``model_name``.

    Args:
        language: A language code or language name in any letter case, or None
            to leave the language undecided.
        model_name: Whisper model name; names ending in ".en" are
            English-only.

    Returns:
        The lower-case language code, "en" for an English-only model, or None
        when no language was given and the model is multilingual.

    Raises:
        ValueError: If ``language`` is neither a known code nor a known name.
    """
    # None means "not specified" and must pass through validation untouched;
    # only an explicit value is checked against the known codes and names.
    if language is not None:
        language = language.lower()
        if language not in LANGUAGES:
            if language not in TO_LANGUAGE_CODE:
                raise ValueError(f"Unsupported language: {language}")
            language = TO_LANGUAGE_CODE[language]

    if model_name.endswith(".en") and language != "en":
        if language is not None:
            logging.warning(
                f"{model_name} is an English-only model but received '{language}'; using English instead."
            )
        language = "en"
    return language


def transcribe(
    audio_file: str,
    language: str,
    model_name: str,
    compute_dtype: str,
    suppress_numerals: bool,
    device: str,
):
    """Transcribe ``audio_file`` with faster-whisper (non-batched).

    Args:
        audio_file: Path of the audio file to transcribe.
        language: Language code, or None to let the model decide.
        model_name: Whisper model name handed to ``WhisperModel``.
        compute_dtype: Compute type handed to ``WhisperModel``, for example
            "float16" or "int8_float16" on GPU and "int8" on CPU.
        suppress_numerals: When true, tokens containing digits or % $ £ are
            passed as ``suppress_tokens``.
        device: Device handed to ``WhisperModel``.

    Returns:
        ``(segments, language)`` where ``segments`` is a list of the segment
        namedtuples converted to dicts and ``language`` is the argument as
        passed in.
    """
    # WhisperModel, find_numeral_symbol_tokens and wav2vec2_langs come from the
    # notebook's own import and helper cells; there is no separate `helpers`
    # module to import them from.
    whisper_model = WhisperModel(model_name, device=device, compute_type=compute_dtype)

    numeral_symbol_tokens = (
        find_numeral_symbol_tokens(whisper_model.hf_tokenizer)
        if suppress_numerals
        else None
    )

    # Whisper's own word timestamps are only requested when the language is
    # unknown or not covered by the wav2vec2 alignment model tables.
    word_timestamps = language is None or language not in wav2vec2_langs

    segments, info = whisper_model.transcribe(
        audio_file,
        language=language,
        beam_size=5,
        word_timestamps=word_timestamps,
        suppress_tokens=numeral_symbol_tokens,
        vad_filter=True,
    )
    whisper_results = [segment._asdict() for segment in segments]

    # clear gpu vram
    del whisper_model
    torch.cuda.empty_cache()

    # NOTE(review): the language argument is returned unchanged, so an
    # auto-detected language (available on `info`) is not reported to the
    # caller. Confirm whether that is intended.
    return whisper_results, language


def transcribe_batched(
    audio_file: str,
    language: str,
    batch_size: int,
    model_name: str,
    compute_dtype: str,
    suppress_numerals: bool,
    device: str,
):
    """Transcribe ``audio_file`` with the batched whisperX pipeline.

    Args:
        audio_file: Path of the audio file to transcribe.
        language: Language code, or None.
        batch_size: Batch size passed to the pipeline's ``transcribe``.
        model_name: Whisper model name.
        compute_dtype: Compute type passed to ``whisperx.load_model``.
        suppress_numerals: Forwarded as the "suppress_numerals" ASR option.
        device: Device the model is loaded on.

    Returns:
        ``(segments, language)`` taken from the pipeline result.
    """
    asr_options = {"suppress_numerals": suppress_numerals}
    batched_model = whisperx.load_model(
        model_name, device, compute_type=compute_dtype, asr_options=asr_options
    )
    waveform = whisperx.load_audio(audio_file)
    transcription = batched_model.transcribe(
        waveform, language=language, batch_size=batch_size
    )

    # Drop the model before returning so that cached GPU memory can be released.
    del batched_model
    torch.cuda.empty_cache()

    return transcription["segments"], transcription["language"]

In [None]:
# Lower-case string without accents or punctuation, with spaces turned into "_"
def cleanString(string):
    """Return a normalised, lower-case form of ``string``.

    The text is passed through ``unidecode``, every character that is neither
    a word character nor whitespace is removed, and each space becomes an
    underscore.
    """
    transliterated = unidecode(string)
    without_punctuation = re.sub(r'[^\w\s]', '', transliterated)
    return without_punctuation.replace(" ", "_").lower()

# rename the audio file so its name is lower-case without accents, spaces or punctuation
def rename_file(filepath):
    """Rename ``filepath`` to a cleaned file name within the same directory.

    Only the file name without its extension is passed through
    ``cleanString``; the directory part and the extension are kept as they
    are.

    Args:
        filepath: Path of an existing file, with or without a directory part.

    Returns:
        The new path as a string.
    """
    source = Path(filepath)
    # Clean the stem alone. Cleaning the whole path would strip the path
    # separators and fold the directory names into the file name.
    target = source.with_name(cleanString(source.stem) + source.suffix)
    new_filepath = str(target)
    os.rename(filepath, new_filepath)
    return new_filepath

## Setup options

In [None]:
# Path of the audio file to process
audio_path = "speech.m4a"

# Rename the file on disk so its name is lower-case without accents or spaces;
# audio_path is updated to the path returned by rename_file.
audio_path = rename_file(audio_path)

In [None]:
# (Optional) Whether to remove music from the speech before processing; helps
# increase diarization quality but uses a lot of RAM
enable_stemming = False

# Whisper model name, one of: 'tiny.en', 'tiny', 'base.en', 'base', 'small.en',
# 'small', 'medium.en', 'medium', 'large-v1', 'large-v2', 'large-v3', 'large'
whisper_model_name = "large-v3"

# Replaces numerical digits with their pronunciation, increases diarization accuracy
suppress_numerals = True

# Batch size for batched transcription; 0 selects the non-batched transcribe() path
batch_size = 8

language = None  # autodetect language

# Use the GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# check device
device

'cuda'

# Processing

(Optional) Separating music from speech using Demucs

---

By isolating the vocals from the rest of the audio, it becomes easier to identify and track individual speakers based on the spectral and temporal characteristics of their speech signals. Source separation is just one of many techniques that can be used as a preprocessing step to help improve the accuracy and reliability of the overall diarization process.

In [None]:
%%time
# Start from the untouched recording; only a successful Demucs run replaces it.
vocal_target = audio_path

if enable_stemming:
    # Isolate vocals from the rest of the audio
    demucs_command = (
        f'python3 -m demucs.separate -n htdemucs --two-stems=vocals "{audio_path}" -o "temp_outputs"'
    )
    return_code = os.system(demucs_command)

    if return_code == 0:
        track_name = os.path.splitext(os.path.basename(audio_path))[0]
        vocal_target = os.path.join("temp_outputs", "htdemucs", track_name, "vocals.wav")
    else:
        logging.warning("Source splitting failed, using original audio file.")

CPU times: user 21 µs, sys: 0 ns, total: 21 µs
Wall time: 23.8 µs


## Transcribing audio using Whisper and realigning timestamps using Wav2Vec2
---
This code uses two different open-source models to transcribe speech and perform forced alignment on the resulting transcription.

The first model is called OpenAI Whisper, which is a speech recognition model that can transcribe speech with high accuracy. The code loads the whisper model and uses it to transcribe the vocal_target file.

The output of the transcription process is a set of text segments with corresponding timestamps indicating when each segment was spoken.


In [None]:
%%time
# float16 on the GPU, int8 on the CPU ("int8_float16" is the INT8 option for GPU).
compute_type = "float16" if device == "cuda" else "int8"

# batch_size == 0 selects plain faster-whisper; any other value the batched whisperX path.
if batch_size == 0:
    whisper_results, language = transcribe(
        audio_file=vocal_target,
        language=language,
        model_name=whisper_model_name,
        compute_dtype=compute_type,
        suppress_numerals=suppress_numerals,
        device=device,
    )
else:
    whisper_results, language = transcribe_batched(
        audio_file=vocal_target,
        language=language,
        batch_size=batch_size,
        model_name=whisper_model_name,
        compute_dtype=compute_type,
        suppress_numerals=suppress_numerals,
        device=device,
    )

config.json:   0%|          | 0.00/2.39k [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/340 [00:00<?, ?B/s]

vocabulary.json:   0%|          | 0.00/1.07M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.48M [00:00<?, ?B/s]

model.bin:   0%|          | 0.00/3.09G [00:00<?, ?B/s]

No language specified, language will be first be detected for each audio file (increases inference time).


100%|█████████████████████████████████████| 16.9M/16.9M [00:01<00:00, 13.9MiB/s]
INFO:pytorch_lightning.utilities.migration.utils:Lightning automatically upgraded your loaded checkpoint from v1.5.4 to v2.0.7. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file ../root/.cache/torch/whisperx-vad-segmentation.bin`


Model was trained with pyannote.audio 0.0.1, yours is 3.1.0. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.10.0+cu102, yours is 2.0.1+cu117. Bad things might happen unless you revert torch to 1.x.
Detected language: en (1.00) in first 30s of audio...
Suppressing numeral and symbol tokens: [3, 4, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 502, 568, 628, 805, 945, 1017, 1025, 1266, 1294, 1360, 1386, 1525, 1614, 1649, 1722, 1848, 1958, 2009, 2119, 2217, 2272, 2319, 2331, 2443, 2625, 2803, 2975, 3165, 3279, 3282, 3356, 3405, 3446, 3499, 3552, 3705, 4022, 4060, 4289, 4303, 4436, 4550, 4688, 4702, 4762, 4808, 5080, 5211, 5254, 5285, 5348, 5853, 5867, 5923, 6071, 6074, 6096, 6375, 6494, 6549, 6591, 6641, 6673, 6856, 6866, 6879, 6905, 6976, 7143, 7201, 7271, 7490, 7526, 7546, 7551, 7560, 7562, 7629, 7634, 7668, 7771, 7773, 7911, 7998, 8132, 8227, 8423, 8451, 8465, 8494, 8652, 8794, 8858, 8923, 9012, 9125, 9356, 9413, 9562, 9657, 9714, 9754, 10076,

## Aligning the transcription with the original audio using Wav2Vec2
---
The second model used is called wav2vec2, which is a large-scale neural network that is designed to learn representations of speech that are useful for a variety of speech processing tasks, including speech recognition and alignment.

The code loads the wav2vec2 alignment model and uses it to align the transcription segments with the original audio signal contained in the vocal_target file. This process involves finding the exact timestamps in the audio signal where each segment was spoken and aligning the text accordingly.

By combining the outputs of the two models, the code produces a fully aligned transcription of the speech contained in the vocal_target file. This aligned transcription can be useful for a variety of speech processing tasks, such as speaker diarization, sentiment analysis, and language identification.

If there's no Wav2Vec2 model available for your language, word timestamps generated by whisper will be used instead.

In [None]:
%%time
if language in wav2vec2_langs:
    # Align on the device selected in the options cell; forcing "cuda" here
    # would fail on CPU-only machines and overwrite the notebook-wide setting.
    alignment_model, metadata = whisperx.load_align_model(
        language_code=language, device=device
    )
    result_aligned = whisperx.align(
        whisper_results, alignment_model, metadata, vocal_target, device
    )
    word_timestamps = filter_missing_timestamps(result_aligned["word_segments"])

    # clear gpu vram
    del alignment_model
    torch.cuda.empty_cache()
else:
    # Without an alignment model, the word timestamps must come from whisper
    # itself, which only the non-batched path (batch_size == 0) produces.
    assert batch_size == 0, (  # TODO: add a better check for word timestamps existence
        f"Unsupported language: {language}, use --batch_size to 0"
        " to generate word timestamps using whisper directly and fix this error."
    )
    word_timestamps = []
    for segment in whisper_results:
        for word in segment["words"]:
            # each word is indexed positionally: 0 = start, 1 = end, 2 = text
            word_timestamps.append({"word": word[2], "start": word[0], "end": word[1]})

Downloading: "https://download.pytorch.org/torchaudio/models/wav2vec2_fairseq_base_ls960_asr_ls960.pth" to /root/.cache/torch/hub/checkpoints/wav2vec2_fairseq_base_ls960_asr_ls960.pth
100%|██████████| 360M/360M [00:03<00:00, 114MB/s]


CPU times: user 4.15 s, sys: 1.67 s, total: 5.82 s
Wall time: 9.78 s


## Convert audio to mono for NeMo compatibility

In [None]:
%%time
# Down-mix to a single channel and store the result where create_config() expects it
sound = AudioSegment.from_file(vocal_target).set_channels(1)
ROOT = os.getcwd()
temp_path = os.path.join(ROOT, "temp_outputs")
os.makedirs(temp_path, exist_ok=True)
# export() hands back the still-open output file; close it so the handle is
# not leaked and its repr is not echoed as the cell output.
sound.export(os.path.join(temp_path, "mono_file.wav"), format="wav").close()

CPU times: user 49.1 ms, sys: 27.7 ms, total: 76.8 ms
Wall time: 430 ms


<_io.BufferedRandom name='/content/temp_outputs/mono_file.wav'>

## Speaker Diarization using NeMo MSDD Model
---
This code uses a model called Nvidia NeMo MSDD (Multi-scale Diarization Decoder) to perform speaker diarization on an audio signal. Speaker diarization is the process of separating an audio signal into different segments based on who is speaking at any given time.

In [None]:
%%time
# Initialize NeMo MSDD diarization model on the device selected in the options cell
# DOMAIN_TYPE: can be meeting, telephonic, or general based on domain type of the audio file
msdd_model = NeuralDiarizer(cfg=create_config(temp_path, DOMAIN_TYPE="telephonic")).to(device)
msdd_model.diarize()

# Release the model so that cached GPU memory can be freed
del msdd_model
torch.cuda.empty_cache()

[NeMo I 2024-01-24 20:10:39 msdd_models:1092] Loading pretrained diar_msdd_telephonic model from NGC
[NeMo I 2024-01-24 20:10:39 cloud:68] Downloading from: https://api.ngc.nvidia.com/v2/models/nvidia/nemo/diar_msdd_telephonic/versions/1.0.1/files/diar_msdd_telephonic.nemo to /root/.cache/torch/NeMo/NeMo_1.21.0/diar_msdd_telephonic/3c3697a0a46f945574fa407149975a13/diar_msdd_telephonic.nemo
[NeMo I 2024-01-24 20:11:14 common:913] Instantiating model from pre-trained checkpoint


[NeMo W 2024-01-24 20:11:15 modelPT:161] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.
    Train config : 
    manifest_filepath: null
    emb_dir: null
    sample_rate: 16000
    num_spks: 2
    soft_label_thres: 0.5
    labels: null
    batch_size: 15
    emb_batch_size: 0
    shuffle: true
    
[NeMo W 2024-01-24 20:11:15 modelPT:168] If you intend to do validation, please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method and provide a valid configuration file to setup the validation data loader(s). 
    Validation config : 
    manifest_filepath: null
    emb_dir: null
    sample_rate: 16000
    num_spks: 2
    soft_label_thres: 0.5
    labels: null
    batch_size: 15
    emb_batch_size: 0
    shuffle: false
    
[NeMo W 2024-01-24 20:11:15 modelPT:174] Please call the ModelPT.setup_test_data() or ModelPT.setup_multiple

[NeMo I 2024-01-24 20:11:15 features:289] PADDING: 16
[NeMo I 2024-01-24 20:11:15 features:289] PADDING: 16
[NeMo I 2024-01-24 20:11:16 save_restore_connector:249] Model EncDecDiarLabelModel was successfully restored from /root/.cache/torch/NeMo/NeMo_1.21.0/diar_msdd_telephonic/3c3697a0a46f945574fa407149975a13/diar_msdd_telephonic.nemo.
[NeMo I 2024-01-24 20:11:16 features:289] PADDING: 16
[NeMo I 2024-01-24 20:11:17 clustering_diarizer:127] Loading pretrained vad_multilingual_marblenet model from NGC
[NeMo I 2024-01-24 20:11:17 cloud:68] Downloading from: https://api.ngc.nvidia.com/v2/models/nvidia/nemo/vad_multilingual_marblenet/versions/1.10.0/files/vad_multilingual_marblenet.nemo to /root/.cache/torch/NeMo/NeMo_1.21.0/vad_multilingual_marblenet/670f425c7f186060b7a7268ba6dfacb2/vad_multilingual_marblenet.nemo
[NeMo I 2024-01-24 20:11:18 common:913] Instantiating model from pre-trained checkpoint


[NeMo W 2024-01-24 20:11:18 modelPT:161] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.
    Train config : 
    manifest_filepath: /manifests/ami_train_0.63.json,/manifests/freesound_background_train.json,/manifests/freesound_laughter_train.json,/manifests/fisher_2004_background.json,/manifests/fisher_2004_speech_sampled.json,/manifests/google_train_manifest.json,/manifests/icsi_all_0.63.json,/manifests/musan_freesound_train.json,/manifests/musan_music_train.json,/manifests/musan_soundbible_train.json,/manifests/mandarin_train_sample.json,/manifests/german_train_sample.json,/manifests/spanish_train_sample.json,/manifests/french_train_sample.json,/manifests/russian_train_sample.json
    sample_rate: 16000
    labels:
    - background
    - speech
    batch_size: 256
    shuffle: true
    is_tarred: false
    tarred_audio_filepaths: null
    tarred_shard_strategy: sca

[NeMo I 2024-01-24 20:11:18 features:289] PADDING: 16
[NeMo I 2024-01-24 20:11:18 save_restore_connector:249] Model EncDecClassificationModel was successfully restored from /root/.cache/torch/NeMo/NeMo_1.21.0/vad_multilingual_marblenet/670f425c7f186060b7a7268ba6dfacb2/vad_multilingual_marblenet.nemo.
[NeMo I 2024-01-24 20:11:18 msdd_models:864] Multiscale Weights: [1, 1, 1, 1, 1]
[NeMo I 2024-01-24 20:11:18 msdd_models:865] Clustering Parameters: {
        "oracle_num_speakers": false,
        "max_num_speakers": 8,
        "enhanced_count_thres": 80,
        "max_rp_threshold": 0.25,
        "sparse_search_volume": 30,
        "maj_vote_spk_count": false,
        "chunk_cluster_count": 50,
        "embeddings_per_chunk": 10000
    }
[NeMo I 2024-01-24 20:11:18 speaker_utils:93] Number of files to diarize: 1
[NeMo I 2024-01-24 20:11:18 clustering_diarizer:309] Split long audio file to avoid CUDA memory issue


splitting manifest: 100%|██████████| 1/1 [00:10<00:00, 10.53s/it]

[NeMo I 2024-01-24 20:11:28 classification_models:273] Perform streaming frame-level VAD
[NeMo I 2024-01-24 20:11:28 collections:301] Filtered duration for loading collection is  0.00 hours.
[NeMo I 2024-01-24 20:11:28 collections:302] Dataset loaded with 2 items, total duration of  0.02 hours.
[NeMo I 2024-01-24 20:11:28 collections:304] # 2 files loaded accounting to # 1 labels



vad: 100%|██████████| 2/2 [00:00<00:00,  3.28it/s]

[NeMo I 2024-01-24 20:11:29 clustering_diarizer:250] Generating predictions with overlapping input segments



                                                               

[NeMo I 2024-01-24 20:11:30 clustering_diarizer:262] Converting frame level prediction to speech/no-speech segment in start and end times format.


creating speech segments: 100%|██████████| 1/1 [00:00<00:00,  9.63it/s]

[NeMo I 2024-01-24 20:11:30 clustering_diarizer:287] Subsegmentation for embedding extraction: scale0, /content/temp_outputs/speaker_outputs/subsegments_scale0.json
[NeMo I 2024-01-24 20:11:30 clustering_diarizer:343] Extracting embeddings for Diarization
[NeMo I 2024-01-24 20:11:30 collections:301] Filtered duration for loading collection is  0.00 hours.
[NeMo I 2024-01-24 20:11:30 collections:302] Dataset loaded with 59 items, total duration of  0.02 hours.
[NeMo I 2024-01-24 20:11:30 collections:304] # 59 files loaded accounting to # 1 labels



[1/5] extract embeddings: 100%|██████████| 1/1 [00:00<00:00,  4.45it/s]

[NeMo I 2024-01-24 20:11:30 clustering_diarizer:389] Saved embedding files to /content/temp_outputs/speaker_outputs/embeddings
[NeMo I 2024-01-24 20:11:30 clustering_diarizer:287] Subsegmentation for embedding extraction: scale1, /content/temp_outputs/speaker_outputs/subsegments_scale1.json
[NeMo I 2024-01-24 20:11:30 clustering_diarizer:343] Extracting embeddings for Diarization
[NeMo I 2024-01-24 20:11:30 collections:301] Filtered duration for loading collection is  0.00 hours.
[NeMo I 2024-01-24 20:11:30 collections:302] Dataset loaded with 73 items, total duration of  0.02 hours.
[NeMo I 2024-01-24 20:11:30 collections:304] # 73 files loaded accounting to # 1 labels



[2/5] extract embeddings: 100%|██████████| 2/2 [00:00<00:00,  8.58it/s]

[NeMo I 2024-01-24 20:11:30 clustering_diarizer:389] Saved embedding files to /content/temp_outputs/speaker_outputs/embeddings
[NeMo I 2024-01-24 20:11:30 clustering_diarizer:287] Subsegmentation for embedding extraction: scale2, /content/temp_outputs/speaker_outputs/subsegments_scale2.json
[NeMo I 2024-01-24 20:11:31 clustering_diarizer:343] Extracting embeddings for Diarization
[NeMo I 2024-01-24 20:11:31 collections:301] Filtered duration for loading collection is  0.00 hours.
[NeMo I 2024-01-24 20:11:31 collections:302] Dataset loaded with 87 items, total duration of  0.02 hours.
[NeMo I 2024-01-24 20:11:31 collections:304] # 87 files loaded accounting to # 1 labels



[3/5] extract embeddings: 100%|██████████| 2/2 [00:00<00:00,  8.29it/s]

[NeMo I 2024-01-24 20:11:31 clustering_diarizer:389] Saved embedding files to /content/temp_outputs/speaker_outputs/embeddings
[NeMo I 2024-01-24 20:11:31 clustering_diarizer:287] Subsegmentation for embedding extraction: scale3, /content/temp_outputs/speaker_outputs/subsegments_scale3.json
[NeMo I 2024-01-24 20:11:31 clustering_diarizer:343] Extracting embeddings for Diarization
[NeMo I 2024-01-24 20:11:31 collections:301] Filtered duration for loading collection is  0.00 hours.
[NeMo I 2024-01-24 20:11:31 collections:302] Dataset loaded with 123 items, total duration of  0.02 hours.
[NeMo I 2024-01-24 20:11:31 collections:304] # 123 files loaded accounting to # 1 labels



[4/5] extract embeddings: 100%|██████████| 2/2 [00:00<00:00,  7.11it/s]

[NeMo I 2024-01-24 20:11:31 clustering_diarizer:389] Saved embedding files to /content/temp_outputs/speaker_outputs/embeddings
[NeMo I 2024-01-24 20:11:31 clustering_diarizer:287] Subsegmentation for embedding extraction: scale4, /content/temp_outputs/speaker_outputs/subsegments_scale4.json
[NeMo I 2024-01-24 20:11:31 clustering_diarizer:343] Extracting embeddings for Diarization
[NeMo I 2024-01-24 20:11:31 collections:301] Filtered duration for loading collection is  0.00 hours.
[NeMo I 2024-01-24 20:11:31 collections:302] Dataset loaded with 189 items, total duration of  0.03 hours.
[NeMo I 2024-01-24 20:11:31 collections:304] # 189 files loaded accounting to # 1 labels



[5/5] extract embeddings: 100%|██████████| 3/3 [00:00<00:00,  6.22it/s]

[NeMo I 2024-01-24 20:11:32 clustering_diarizer:389] Saved embedding files to /content/temp_outputs/speaker_outputs/embeddings



clustering: 100%|██████████| 1/1 [00:01<00:00,  1.08s/it]

[NeMo I 2024-01-24 20:11:33 clustering_diarizer:464] Outputs are saved in /content/temp_outputs directory



[NeMo W 2024-01-24 20:11:33 der:185] Check if each ground truth RTTMs were present in the provided manifest file. Skipping calculation of Diariazation Error Rate


[NeMo I 2024-01-24 20:11:33 msdd_models:960] Loading embedding pickle file of scale:0 at /content/temp_outputs/speaker_outputs/embeddings/subsegments_scale0_embeddings.pkl
[NeMo I 2024-01-24 20:11:33 msdd_models:960] Loading embedding pickle file of scale:1 at /content/temp_outputs/speaker_outputs/embeddings/subsegments_scale1_embeddings.pkl
[NeMo I 2024-01-24 20:11:33 msdd_models:960] Loading embedding pickle file of scale:2 at /content/temp_outputs/speaker_outputs/embeddings/subsegments_scale2_embeddings.pkl
[NeMo I 2024-01-24 20:11:33 msdd_models:960] Loading embedding pickle file of scale:3 at /content/temp_outputs/speaker_outputs/embeddings/subsegments_scale3_embeddings.pkl
[NeMo I 2024-01-24 20:11:33 msdd_models:960] Loading embedding pickle file of scale:4 at /content/temp_outputs/speaker_outputs/embeddings/subsegments_scale4_embeddings.pkl
[NeMo I 2024-01-24 20:11:33 msdd_models:938] Loading cluster label file from /content/temp_outputs/speaker_outputs/subsegments_scale4_cluste

100%|██████████| 1/1 [00:00<00:00, 24.77it/s]

[NeMo I 2024-01-24 20:11:33 msdd_models:1403]      [Threshold: 0.7000] [use_clus_as_main=False] [diar_window=50]
[NeMo I 2024-01-24 20:11:33 speaker_utils:93] Number of files to diarize: 1
[NeMo I 2024-01-24 20:11:33 speaker_utils:93] Number of files to diarize: 1



[NeMo W 2024-01-24 20:11:33 der:185] Check if each ground truth RTTMs were present in the provided manifest file. Skipping calculation of Diariazation Error Rate


[NeMo I 2024-01-24 20:11:33 speaker_utils:93] Number of files to diarize: 1


[NeMo W 2024-01-24 20:11:33 der:185] Check if each ground truth RTTMs were present in the provided manifest file. Skipping calculation of Diariazation Error Rate


[NeMo I 2024-01-24 20:11:33 speaker_utils:93] Number of files to diarize: 1


[NeMo W 2024-01-24 20:11:33 der:185] Check if each ground truth RTTMs were present in the provided manifest file. Skipping calculation of Diariazation Error Rate


[NeMo I 2024-01-24 20:11:33 msdd_models:1431]   
    
CPU times: user 15.8 s, sys: 1.05 s, total: 16.9 s
Wall time: 54.1 s


## Mapping Speakers to Sentences According to Timestamps

In [None]:
%%time
# Reading timestamps <> Speaker Labels mapping
# Parses the predicted RTTM file into speaker_ts, a list of
# [start_ms, end_ms, speaker_index] entries, then maps each word to a speaker.

speaker_ts = []
with open(os.path.join(temp_path, "pred_rttms", "mono_file.rttm"), "r") as f:
    lines = f.readlines()

for line in lines:
    # Split on any run of whitespace so the field positions do not depend on
    # the exact number of padding spaces between columns.
    line_list = line.split()
    if not line_list:
        # Tolerate blank lines instead of failing with an IndexError.
        continue
    # In the standard RTTM column order, fields 3 and 4 are the turn onset and
    # duration (seconds) and field 7 is the speaker name.
    s = int(float(line_list[3]) * 1000)
    e = s + int(float(line_list[4]) * 1000)
    # The speaker index is the numeric suffix after the last underscore.
    speaker_ts.append([s, e, int(line_list[7].split("_")[-1])])

wsm = get_words_speaker_mapping(word_timestamps, speaker_ts, "start")

CPU times: user 1.28 ms, sys: 767 µs, total: 2.05 ms
Wall time: 1.89 ms


## Realigning Speech segments using Punctuation
---

This code provides a method for disambiguating speaker labels in cases where a sentence is split between two different speakers. It uses punctuation marks to determine the dominant speaker for each sentence in the transcription.

```
Speaker A: It's got to come from somewhere else. Yeah, that one's also fun because you know the lows are
Speaker B: going to suck, right? So it's actually it hits you on both sides.
```

For example, if a sentence is split between two speakers, the code takes the mode of the speaker labels of the words in that sentence and uses that label for the whole sentence. This can help to improve the accuracy of speaker diarization, especially in cases where the Whisper model does not transcribe brief utterances like "hmm" and "yeah" but the diarization model (NeMo) does detect them, leading to inconsistent results.

The code also handles cases where one speaker is giving a monologue while other speakers are making occasional comments in the background. It ignores the comments and assigns the entire monologue to the speaker who is speaking the majority of the time. This provides a robust and reliable method for realigning speech segments to their respective speakers based on punctuation in the transcription.

In [None]:
%%time
if language in punct_model_langs:
    # restoring punctuation in the transcript to help realign the sentences
    punct_model = PunctuationModel(model="kredor/punctuate-all")

    words_list = [word_dict["word"] for word_dict in wsm]

    # Each prediction is paired with its word below; element [1] of a
    # prediction is read as the punctuation mark predicted for that word.
    labled_words = punct_model.predict(words_list)

    ending_puncts = ".?!"
    model_puncts = ".,;:!?"

    # We don't want to punctuate U.S.A. with a period. Right?
    def is_acronym(x):
        """Return a match object when x is made only of 2+ "letter." groups, else None."""
        return re.fullmatch(r"\b(?:[a-zA-Z]\.){2,}", x)

    for word_dict, labeled_tuple in zip(wsm, labled_words):
        word = word_dict["word"]
        # Append the predicted mark only when it ends a sentence and the word
        # is not already punctuated (acronyms are let through despite their
        # trailing period).
        if (
            word
            and labeled_tuple[1] in ending_puncts
            and (word[-1] not in model_puncts or is_acronym(word))
        ):
            word += labeled_tuple[1]
            if word.endswith(".."):
                # Only an acronym that just received a predicted "." can end in
                # "..". Drop just the appended period so "U.S.A." stays
                # "U.S.A."; rstrip(".") would remove both and yield "U.S.A".
                word = word[:-1]
            word_dict["word"] = word

else:
    logging.warning(
        f"Punctuation restoration is not available for {language} language. Using the original punctuation."
    )

# Realign word-level speaker labels on sentence boundaries, then group the
# words into per-speaker sentences.
wsm = get_realigned_ws_mapping_with_punctuation(wsm)
ssm = get_sentences_speaker_mapping(wsm, speaker_ts)

config.json:   0%|          | 0.00/914 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.11G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/447 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.1M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

CPU times: user 5.49 s, sys: 3.92 s, total: 9.41 s
Wall time: 22.1 s


In [None]:
# Inline stylesheet for the transcript paragraphs rendered below: bold,
# enlarged text plus one colour class per speaker.
css = '''
        <style>
            .highlighted {
                font-weight: bold;
                font-size: 140%;
                margin-bottom: 20px;
                padding: 10px
            }
            .redstyle {
                color:red
            }
            .greenstyle {
                color:green
            }
        </style>
      '''

import html

from IPython.display import display, HTML

# NOTE(review): `lines` is not assigned in this cell; it must already exist in
# the kernel. Confirm that it holds the transcript lines intended for display.
for line in lines:
    # Escape before inserting the speaker tags so "<", ">" and "&" in the text
    # are shown literally instead of being parsed as markup.
    line = html.escape(line, quote=False)
    line = line.replace("Speaker 0", "<span class='redstyle'>Speaker 0</span>")
    line = line.replace("Speaker 1", "<span class='greenstyle'>Speaker 1</span>")
    display(HTML(f'{css} <p class="highlighted">{line}</p>'))