In [17]:
import logging
import os
import shutil
import subprocess
import tempfile
from math import ceil, floor
from pathlib import Path

import pandas as pd
import torch

import nemo.collections.asr as nemo_asr
from ctcdecode import CTCBeamDecoder
from nemo.collections.nlp.models import PunctuationCapitalizationModel
from pydub import AudioSegment
from pydub.silence import detect_silence
from pyannote.audio import Pipeline

# Raise the "nemo_logger" logger's threshold to ERROR so lower-level NeMo records are dropped.
logging.getLogger("nemo_logger").setLevel(logging.ERROR)

# Progress logger for this notebook's transcription pipeline.
asr_logger = logging.getLogger("asr")
asr_logger.setLevel(logging.INFO)
# With no handler anywhere in the logger hierarchy, records go to logging.lastResort,
# whose level is WARNING, so the INFO progress messages would be silently dropped.
# Attach a stream handler only in that case: an existing root configuration is
# respected, and re-running this cell does not add duplicate handlers.
if not asr_logger.hasHandlers():
    _asr_handler = logging.StreamHandler()
    _asr_handler.setFormatter(
        logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
    )
    asr_logger.addHandler(_asr_handler)


In [33]:
# Pretrained model identifiers.
dia_model_name = "pyannote/speaker-diarization@2022.07"
# alternative ASR model: "QuartzNet15x5Base-En" (second_max_audio is raised to 120 s for it)
asr_model_name = "stt_en_conformer_ctc_small"
punct_model_name = "punctuation_en_bert"

# Speaker diarisation, CTC speech recognition and punctuation/capitalisation models.
dia_model = Pipeline.from_pretrained(dia_model_name)
asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name=asr_model_name)
punct_model = PunctuationCapitalizationModel.from_pretrained(punct_model_name)

Using eos_token, but it is not set yet.
Using bos_token, but it is not set yet.
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-uncased were not used whe

In [4]:
use_gpu = True  # NOTE(review): not read anywhere in the visible notebook
device = 0 if torch.cuda.is_available() else -1  # NOTE(review): not passed to any model here

pause_threshold = 1  # RE: collapsing diarised segments; NOTE(review): not used by _diarize_mono_audio yet
batch_size = 4  # batch size for asr_model.transcribe
offset = -0.18  # calibration offset for timestamps: 180 ms; NOTE(review): not applied anywhere yet

# CTC decoder labels: the ASR vocabulary plus an appended blank label "_".
# Build a new list rather than appending to asr_model.decoder.vocabulary in place:
# the in-place append mutated the model's own vocabulary and added one more "_"
# every time this cell was re-run.
vocab = list(asr_model.decoder.vocabulary) + ["_"]
decoder = CTCBeamDecoder(
    vocab,
    beam_width=1,
    # index of the appended "_"; vocab.index("_") would return an earlier "_"
    # if the model vocabulary already contained one
    blank_id=len(vocab) - 1,
    log_probs_input=True,
)

# Duration of one decoder timestep in ms: _format_word_timestamps multiplies
# timestep indices by this value and divides the result by 1000 to get seconds.
# NOTE(review): 1 / window_size is not obviously the ms-per-output-frame of the
# model (usually window_stride x encoder subsampling factor) -- confirm per model.
time_stride = 1 / asr_model.cfg.preprocessor.window_size

# Padding, in decoder timesteps, around each word's start/end timestep.
TIME_PAD = 1
time_pad = TIME_PAD  # lowercase alias kept for existing references
# huge possible max audio if model is Quartznet; maximise where possible to limit segmentation transcription error
second_max_audio = 120 if asr_model_name == "QuartzNet15x5Base-En" else 4
round_value = 3  # NOTE(review): not read anywhere in the visible notebook


In [5]:
def _resample_normalize_audio(in_file, out_file, sample_rate=16000):
    # upsample/normalize audio to 16khz WAV
    # via https://github.com/NVIDIA/NeMo/blob/main/tutorials/tools/CTC_Segmentation_Tutorial.ipynb
    if not os.path.exists(in_file):
        raise ValueError(f"{in_file} not found")
    if out_file is None:
        out_file = in_file.replace(os.path.splitext(in_file)[-1], f"_{sample_rate}.wav")

    os.system(
        f"ffmpeg -i {in_file} -acodec pcm_s16le -ac 1 -af aresample=resampler=soxr -ar {sample_rate} {out_file} -y"
    )
    return out_file

def _split_stereo_audio(in_file, out_dir):
    # into left/right channel wavs
    in_file = Path(in_file)
    out_dir = Path(out_dir)
    assert in_file.exists()
    assert out_dir.exists()

    # format output files
    left_channel = out_dir / f"{Path(in_file).stem}_left.wav"
    right_channel = out_dir / f"{Path(in_file).stem}_right.wav"

    # split, export
    audio_segment = AudioSegment.from_file(in_file)
    monos = audio_segment.split_to_mono()
    assert len(monos) == 2  # cap support at stereo audio
    monos[0].export(left_channel, format="wav")
    monos[1].export(right_channel, format="wav")

    return left_channel, right_channel

In [37]:
def _diarize_mono_audio(in_file):
    diarization_raw = dia_model(str(in_file))
    diarized_segments = (
        pd.DataFrame(
            [
                {"start": turn.start, "end": turn.end, "speaker": speaker}
                for turn, _, speaker in diarization_raw.itertracks(yield_label=True)
            ]
        )
        # shift speaker attribution > mark/collapse consecutive speaker segments
        .assign(segment_marker=lambda x: x.speaker.shift(1)).assign(
            segment_marker=lambda x: x.segment_marker != x.speaker
        )
        .assign(segment_marker=lambda x: pd.Series.cumsum(x.segment_marker))
    )

    diarized_segments = (
        diarized_segments
        # groupby/aggregate shifted, collapse consecutive speaker sequences
        .groupby("segment_marker")
        .agg(
            {
                "speaker": "first",
                "start": "first",
                "end": "last",
                "segment_marker": "count",
            }
        )
        .rename(
            mapper={"segment_marker": "segment_marker_count"},
            axis="columns",
            inplace=False,
        )
        .assign(segment_len=lambda x: x.end - x.start)
        # reconcile very short segments with pre/proceeding segment? merging strategy?
        #             .query('segment_len >= 1')
        .reset_index(drop=True)
    )

    return diarized_segments

# Exploratory check of the diarisation step on one file. The path is the 16 kHz WAV that
# _transcribe_mono writes into '../output/temp_dir' as '<input stem>.wav', so it only
# exists after _transcribe_mono has been run on methadone-stigma.mp3.
x = _diarize_mono_audio('../output/temp_dir/methadone-stigma.wav')

In [41]:
# Short diarised segments (< 5 s) with their times also shown in minutes.
# TODO: decide what to do with these tiny segments (merge into a neighbour?).
(
    x.assign(
        start_min=lambda df: df.start / 60,
        end_min=lambda df: df.end / 60,
    )
    .query("segment_len < 5")
)

Unnamed: 0,speaker,start,end,segment_marker_count,segment_len,start_min,end_min
2,SPEAKER_01,71.254688,74.950312,3,3.695625,1.187578,1.249172
5,SPEAKER_00,212.346563,212.633438,1,0.286875,3.539109,3.543891
6,SPEAKER_01,212.633438,215.215312,2,2.581875,3.543891,3.586922
8,SPEAKER_01,264.237188,269.114063,1,4.876875,4.403953,4.485234
14,SPEAKER_01,415.403438,420.027188,2,4.62375,6.923391,7.000453
18,SPEAKER_02,495.441563,495.492188,1,0.050625,8.257359,8.258203
23,SPEAKER_00,564.764063,565.641563,1,0.8775,9.412734,9.427359


In [7]:
def _format_word_timestamps(asr_output, chunk_offset):
    preds = asr_output.y_sequence.tolist()  # some funky formatting
    probs_seq = torch.FloatTensor([preds])  # some funky formatting
    beam_result, beam_scores, timesteps, out_seq_len = decoder.decode(probs_seq)
    lens = out_seq_len[0][0]
    timesteps = timesteps[0][0]

    result = []

    if len(timesteps) == 0:
        return result

    start = (timesteps[0] - TIME_PAD) * time_stride + chunk_offset
    end = (timesteps[0] + TIME_PAD * 2) * time_stride + chunk_offset

    token_prev = vocab[int(beam_result[0][0][0])]
    word = token_prev

    for n in range(1, lens):
        token = vocab[int(beam_result[0][0][n])]

        if token[0] == "#":
            # merge subwords
            word = word + token[2:]

        elif token[0] == "-" or token_prev[0] == "-":
            word = word + token

        else:
            word = word.replace("▁", "").replace("_", "")  # remove weird token

            result_word = {
                "startTime": int(start) / 1000,
                "endTime": int(end) / 1000,
                "word": word,
            }
            result.append(result_word)

            start = (timesteps[n] - TIME_PAD) * time_stride + chunk_offset
            word = token

        end = (timesteps[n] + TIME_PAD * 2) * time_stride + chunk_offset
        token_prev = token

    # add last word
    word = word.replace("▁", "").replace("_", "")

    result_word = {
        "startTime": int(start) / 1000,
        "endTime": int(end) / 1000,
        "word": word,
    }
    result.append(result_word)
    return result

In [None]:
# def _gcp_format_single_utterance(time_formatted_words_single, channel_tag=None):
#     # GCP STT consistency etc. why are we standardising on this again?
#     string_formatted_word_stamps = []
#     for e in time_formatted_words_single:
#         temp = deepcopy(e)
#         temp["startTime"] = f"{e['startTime']}s"
#         temp["endTime"] = f"{e['endTime']}s"
#         string_formatted_word_stamps.append(temp)

#     return {
#         "alternatives": [
#             {
#                 "transcript": " ".join(e["word"] for e in time_formatted_words_single),
#                 "words": string_formatted_word_stamps,
#             }
#         ],
#         "speakerTag": time_formatted_words_single[0]["speakerTag"],
#         "channelTag": "None" if not channel_tag else channel_tag,
#         "languageCode": "en",
#     }


# def _gcp_format_aggregate_transcript(time_formatted_words_all):
#     # GCP STT consistency etc.
#     transcript_all = [
#         " ".join(e["word"] for e in segment_transcript)
#         for segment_transcript in time_formatted_words_all
#     ]

#     return " ".join(transcript_all)


# def _gcp_format_channel_seperated_transcript_objects(
#     gcp_formatted_left_res, gcp_formatted_right_res
# ):
#     # merge, sort individual left/right transcripts
#     merged_utterances = pd.concat(
#         [
#             format_utterances_df(gcp_formatted_left_res),
#             format_utterances_df(gcp_formatted_right_res),
#         ]
#     ).sort_values(by=["startTime", "endTime"])

#     # use any/left metadata as base (should be the same file right?)
#     merged_metadata = {
#         k: v for k, v in gcp_formatted_left_res["metadata"].items() if k != "transcript"
#     }
#     merged_transcript = " ".join(merged_utterances.transcript.tolist())
#     merged_metadata["transcript"] = merged_transcript

#     return {
#         "metadata": merged_metadata,
#         "streaming_outputs": (
#             merged_utterances.pipe(
#                 lambda x: x[
#                     [
#                         "alternatives",
#                         "speakerTag",
#                         "channelTag",
#                         "languageCode",
#                     ]
#                 ]
#             ).to_dict(orient="records")
#         ),
#     }


In [8]:
def _naively_segment_utterances(record):
    # apply naive splitting
    n_chunks = int((record.end - record.start) // second_max_audio) + 1
    chunk_len = (record.end - record.start) / n_chunks

    df_temp = pd.DataFrame([record] * n_chunks).reset_index(drop=True)
    df_temp["start"] = df_temp.apply(
        lambda x: x.start + chunk_len * x.name, axis=1
    )  # increase start time
    df_temp["end"] = df_temp.apply(
        lambda x: x.start + chunk_len, axis=1
    )  # increase start time
    df_temp.loc[
        (n_chunks - 1), "end"
    ] = (
        record.end
    )  # adjust end time to actual time (sanity correction in case rounding cuts of audio)
    return df_temp.assign(segment_len=lambda x: x.end - x.start)


def _segment_utterances(audio_segment, record):
    dBFS = audio_segment.dBFS  # audio volume (silence level is relative to volume)
    silences = detect_silence(
        audio_segment, min_silence_len=500, silence_thresh=dBFS - 20
    )  # 0.5 break, time in ms, silence_thresh 20 lower than audio volume

    if len(silences) == 0:
        # no silence detected, lower min_silence_len
        silences = detect_silence(
            audio_segment, min_silence_len=100, silence_thresh=dBFS - 20
        )

        if len(silences) == 0:
            # if still no silences detected after lowering min_silence_len, split naively
            return _naively_segment_utterances(record)

    silences = [[(s[1] - s[0]), s[0] / 1000, s[1] / 1000] for s in silences]  # ms -> s

    df_temp = pd.DataFrame(record).T.reset_index(drop=True)

    # split on longest silence, in middle of silence so no info is lost
    while (len(silences) > 0) & any(df_temp.segment_len > second_max_audio):
        longest_silence = silences.pop(silences.index(max(silences)))
        middle_silence = record.start + (
            longest_silence[1] + (longest_silence[2] - longest_silence[1]) / 2
        )

        record_to_split = df_temp.query(
            f"start < {middle_silence} & end>{middle_silence} & segment_len > {second_max_audio}"
        )
        df_temp = df_temp.drop(record_to_split.index)

        split_utterances = pd.DataFrame(
            [record_to_split.iloc[0], record_to_split.iloc[0]]
        ).reset_index(drop=True)
        split_utterances.loc[0, "end"] = middle_silence
        split_utterances.loc[1, "start"] = middle_silence
        df_temp = (
            pd.concat([df_temp, split_utterances])
            .reset_index(drop=True)
            .assign(segment_len=lambda x: x.end - x.start)
        )

    if any(df_temp.segment_len > second_max_audio):
        # if any segments are still too long, naively split them
        final_df = [df_temp.query(f"segment_len < {second_max_audio}")]
        records_to_split = df_temp.query(f"segment_len > {second_max_audio}")

        for i, record in records_to_split.iterrows():
            final_df.append(_naively_segment_utterances(record))
        return pd.concat(final_df).reset_index(drop=True).sort_values(by=["start"])

    return df_temp


In [21]:
from pathlib import Path
import tempfile
import string

# Example run on one podcast episode. _transcribe_mono is defined in a later cell,
# so that cell must be executed before this one.
_transcribe_mono(input_file='../output/radio_national_podcasts/audio/methadone-stigma.mp3')

ffmpeg version 4.2.7-0ubuntu0.1 Copyright (c) 2000-2022 the FFmpeg developers
  built with gcc 9 (Ubuntu 9.4.0-1ubuntu1~20.04.1)
  configuration: --prefix=/usr --extra-version=0ubuntu0.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-avresample --disable-filter=resample --enable-avisynth --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librsvg --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --e

TypeError: int() argument must be a string, a bytes-like object or a number, not 'NoneType'

In [20]:
def _transcribe_mono(input_file, single_speaker=False, temp_dir="../output/temp_dir"):
    """Transcribe a mono audio file into speaker-attributed, punctuated words.

    Pipeline:

    1. Resample to a 16 kHz mono WAV.
    2. Diarise, then split segments longer than ``second_max_audio`` at silences.
    3. Export one WAV per chunk and batch-transcribe the chunks with ``asr_model``.
    4. Build word timestamps.
    5. Apply punctuation/capitalisation.

    Args:
        input_file: path to the input audio (any format ffmpeg can read).
        single_speaker: if True, every segment is attributed to ``"SPEAKER_00"``
            (e.g. one channel of channel-separated audio).
        temp_dir: working directory for the resampled WAV and the chunk WAVs. It is
            deleted and recreated on every call, so it must not hold anything else.
            Pass None to use a fresh temporary directory that is removed afterwards.

    Returns:
        A list with one entry per transcribed segment, in time order. Each entry is a
        list of dicts with keys ``startTime`` and ``endTime`` (seconds), ``word``, and
        ``speakerTag``.

    Raises:
        ValueError: ``input_file`` lies inside ``temp_dir`` (it would be deleted), or
            does not exist.
    """
    input_file = Path(input_file)
    asr_logger.info(f"Transcribing: {input_file}..")

    if temp_dir is None:
        # throwaway working directory, removed together with all chunks on exit
        with tempfile.TemporaryDirectory() as scratch_dir:
            return _transcribe_mono_in_dir(input_file, single_speaker, Path(scratch_dir))

    work_dir = Path(temp_dir)
    # the working directory is wiped below; refuse if that would delete the input itself
    if work_dir.resolve() in input_file.resolve().parents:
        raise ValueError(
            f"{input_file} lies inside temp_dir {work_dir}, which is cleared before use"
        )
    if work_dir.exists():
        shutil.rmtree(work_dir)
    work_dir.mkdir(parents=True)
    return _transcribe_mono_in_dir(input_file, single_speaker, work_dir)


def _transcribe_mono_in_dir(input_file, single_speaker, work_dir):
    # Steps 1-5 of _transcribe_mono, writing intermediate WAVs into the existing work_dir.
    # 1.0 resample, convert to wav
    wav_path = _resample_normalize_audio(input_file, str(work_dir / f"{input_file.stem}.wav"))
    audio_segment = AudioSegment.from_file(wav_path)
    asr_logger.info('Successfully resampled/converted input to WAV')

    # 2.0 diarize input, split long segments
    diarized_segments = _diarize_mono_audio(wav_path)
    if diarized_segments.empty:
        asr_logger.warning(f"No speech segments found in {input_file}")
        return []
    if single_speaker:
        diarized_segments = diarized_segments.assign(speaker="SPEAKER_00")
    chunked_segments = _chunk_diarized_segments(diarized_segments, audio_segment)
    asr_logger.info('Successfully diarised input')

    # export one WAV per chunk; the path list follows row order so ASR outputs align with rows
    paths2audio_files = []
    for idx, record in chunked_segments.iterrows():
        chunk_path = work_dir / f"chunk_{idx}.wav"
        # slice in ms, rounding start down and end up so no audio is cut off
        chunk_audio = audio_segment[floor(record.start * 1000) : ceil(record.end * 1000)]
        # pydub's export() returns the open output file; close it so handles don't pile up
        chunk_audio.export(chunk_path, format="wav").close()
        paths2audio_files.append(str(chunk_path))
    asr_logger.info('Successfully chunked and saved diarised temp chunks')

    # 3.0 batch transcribe, retrieve transcripts, alignments and logprobs for each utterance
    outputs = asr_model.transcribe(
        paths2audio_files=paths2audio_files,
        batch_size=batch_size,
        return_hypotheses=True,
    )
    asr_logger.info('Successfully processed chunks with ASR model')

    # 4.0 retrieve/format timestamps, 5.0 punctuate
    time_formatted_words_all = []
    for idx, record in chunked_segments.iterrows():
        time_formatted_words = _format_word_timestamps(outputs[idx], record.start)
        if time_formatted_words:  # nothing to punctuate in an empty chunk
            time_formatted_words_all.append(
                _punctuate_words(time_formatted_words, record.speaker)
            )
    return time_formatted_words_all


def _chunk_diarized_segments(diarized_segments, audio_segment):
    # Split segments longer than second_max_audio (at silences where possible); rows in start order.
    pieces = []
    for _, record in diarized_segments.iterrows():
        if record.segment_len > second_max_audio:
            segment_audio = audio_segment[floor(record.start * 1000) : ceil(record.end * 1000)]
            pieces.append(_segment_utterances(segment_audio, record))
        else:
            pieces.append(pd.DataFrame(record).T)
    return pd.concat(pieces).sort_values(by="start", kind="mergesort").reset_index(drop=True)


def _punctuate_words(time_formatted_words, speaker):
    # Punctuate/capitalise one segment's words and tag each with speakerTag.
    raw_words = [w["word"] for w in time_formatted_words]
    punctuated = punct_model.add_punctuation_capitalization([" ".join(raw_words)])[0].split(" ")
    if len(punctuated) != len(raw_words):
        # tokens can no longer be aligned one-to-one with timestamps; keep the raw words
        # rather than dropping the segment. See
        # https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/nlp/punctuation_and_capitalization.html
        asr_logger.warning(
            "Punctuated outputs not the same length as input (%d vs %d); keeping unpunctuated words",
            len(punctuated),
            len(raw_words),
        )
        punctuated = raw_words
    return [
        {**word, "word": punct_word, "speakerTag": speaker}
        for word, punct_word in zip(time_formatted_words, punctuated)
    ]


def _transcribe_channel_seperated_audio(input_file):
    with tempfile.TemporaryDirectory() as temp_dir:
        # 1.0 split left/right channels
        left_channel, right_channel = _split_stereo_audio(input_file, temp_dir)

        # 2.0 process as seperate monos
        left_res = _transcribe_mono(left_channel, single_speaker=True)
        right_res = _transcribe_mono(right_channel, single_speaker=True)

    # 3.0 merge outputs
    return _gcp_format_channel_seperated_transcript_objects(left_res, right_res)