In [44]:
#| default_exp transcribe

In [45]:
#| export
from whisperx import load_model, load_audio, load_align_model, align
from whisperx.diarize import assign_word_speakers
import torch
from pathlib import Path
import json
import pandas as pd
from pyannote.audio import Pipeline

In [46]:
from copy import deepcopy

In [47]:
audio_file = "../data/podcasts/lex_ai_stephen_wolfram_1/tmp/audio_formatted.wav"

In [48]:
#| export
# whisperx config
# Module-level defaults picked up by the function signatures defined below.
batch_size = 16  # default `batch_size` of whisper_transcribe, forwarded to model.transcribe
compute_type = "float16"  # default `compute_type` of whisper_transcribe, forwarded to load_model
language = "en"  # default language code of whisperx_align
# Use the GPU when torch reports one, otherwise fall back to the CPU.
# NOTE(review): the functions below default `device` to the literal "cuda"
# rather than to this value — confirm whether that is intended.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [49]:
#| export
def get_tmp_dir(audio_file):
    """Return the directory that contains `audio_file` as a `Path`.

    The other functions in this module write their output files into
    this directory.
    """
    audio_path = Path(audio_file)
    return audio_path.parent

# Transcription

In [50]:
#| export
def whisper_transcribe(audio_file, language="en", batch_size=batch_size, compute_type=compute_type, device="cuda", save=True):
    """Transcribe `audio_file` with the whisperx "large-v2" model.

    Args:
        audio_file: path of the audio file to transcribe.
        language: language code passed to both `load_model` and `transcribe`.
        batch_size: batch size passed to `model.transcribe`.
        compute_type: compute type passed to `load_model`.
        device: device passed to `load_model`.
        save: when true, write the segments to `transcript-whisper.json`
            in the directory containing `audio_file`.

    Returns:
        The `'segments'` entry of the transcription result.
    """
    model = load_model("large-v2", device, language=language, compute_type=compute_type)
    audio = load_audio(audio_file)
    transcript = model.transcribe(audio, language=language, batch_size=batch_size)['segments']
    if save:
        # ensure_ascii=False emits raw non-ASCII characters, so the file must be
        # opened as UTF-8 explicitly instead of relying on the platform default.
        with open(get_tmp_dir(audio_file)/"transcript-whisper.json", "w", encoding="utf-8") as f:
            json.dump(transcript, f, ensure_ascii=False, indent=2)
    return transcript

In [51]:
# Transcribe the sample episode. With the default save=True this also writes
# transcript-whisper.json into the directory containing the audio file.
transcript_whisper = whisper_transcribe(audio_file)

Lightning automatically upgraded your loaded checkpoint from v1.5.4 to v2.0.5. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file ../../../../.cache/torch/whisperx-vad-segmentation.bin`


Model was trained with pyannote.audio 0.0.1, yours is 2.1.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.10.0+cu102, yours is 2.0.1. Bad things might happen unless you revert torch to 1.x.


In [52]:
transcript_whisper[10:12]

[{'text': ' You and your son Christopher helped create the alien language in the movie Arrival. So let me ask maybe a bit of a crazy question, but if aliens were to visit us on earth, do you think we would be able to find a common language? Well,',
  'start': 274.21,
  'end': 290.326},
 {'text': " By the time we're saying aliens are visiting us, we've already prejudiced the whole story. Because the concept of an alien actually visiting, so to speak, we already know they're kind of things that make sense to talk about visiting. So we already know they exist in the same kind of physical setup that we do. It's not just radio signals, it's an actual thing that shows up and so on.",
  'start': 290.63,
  'end': 318.203}]

In [53]:
#| export
def whisperx_align(segments, audio_file, language=language, device="cuda", save=True):
    """Align transcript `segments` against `audio_file` with whisperx.

    Args:
        segments: transcript segments, e.g. the output of `whisper_transcribe`.
        audio_file: path of the audio file; passed to `align` and used to
            locate the output directory.
        language: language code used to pick the alignment model.
        device: device passed to both `load_align_model` and `align`.
        save: when true, write the aligned transcript to
            `transcript-whisperx.json` in the directory containing `audio_file`.

    Returns:
        Whatever `align` returns (in this notebook a dict with the keys
        `'segments'` and `'word_segments'`).
    """
    model, metadata = load_align_model(language_code=language, device=device)
    transcript_aligned = align(segments, model, metadata, audio_file, device=device)
    if save:
        # ensure_ascii=False emits raw non-ASCII characters, so the file must be
        # opened as UTF-8 explicitly instead of relying on the platform default.
        with open(get_tmp_dir(audio_file)/"transcript-whisperx.json", "w", encoding="utf-8") as f:
            json.dump(transcript_aligned, f, ensure_ascii=False, indent=2)
    return transcript_aligned

In [54]:
# Align on a deep copy so that `transcript_whisper` cannot be modified by the
# alignment step and stays available for inspection.
transcript_whisperx = whisperx_align(deepcopy(transcript_whisper), audio_file)

In [55]:
print(transcript_whisperx.keys())
transcript_whisperx['segments'][10:12]

dict_keys(['segments', 'word_segments'])


[{'start': 92.728,
  'end': 100.255,
  'text': 'We now agreed to talk again, probably multiple times in the near future, so this is round one, and stay tuned for round two soon.',
  'words': [{'word': 'We', 'start': 92.728, 'end': 92.828, 'score': 0.966},
   {'word': 'now', 'start': 92.888, 'end': 93.068, 'score': 0.752},
   {'word': 'agreed', 'start': 93.148, 'end': 93.509, 'score': 0.706},
   {'word': 'to', 'start': 93.569, 'end': 93.629, 'score': 1.0},
   {'word': 'talk', 'start': 93.689, 'end': 93.889, 'score': 0.936},
   {'word': 'again,', 'start': 93.969, 'end': 94.229, 'score': 0.844},
   {'word': 'probably', 'start': 94.47, 'end': 94.85, 'score': 0.84},
   {'word': 'multiple', 'start': 94.91, 'end': 95.27, 'score': 0.866},
   {'word': 'times', 'start': 95.31, 'end': 95.591, 'score': 0.802},
   {'word': 'in', 'start': 95.631, 'end': 95.691, 'score': 0.736},
   {'word': 'the', 'start': 95.711, 'end': 95.771, 'score': 0.955},
   {'word': 'near', 'start': 95.831, 'end': 95.991, 'sc

In [56]:
#| export
def diarization_to_df(diarization):
    """Convert a diarization result into a DataFrame, one row per track.

    Args:
        diarization: object exposing `itertracks(yield_label=True)`, which
            yields `(segment, track, label)` triples where `segment` has
            `.start` and `.end` attributes.

    Returns:
        DataFrame with the columns `0` (segment object), `1` (track),
        `"speaker"` (label), `"start"` and `"end"`. An empty diarization
        gives an empty frame with the same columns instead of raising.
    """
    # Materialise once: the tracks are needed for the frame and for start/end.
    tracks = list(diarization.itertracks(yield_label=True))
    # Naming the columns up front keeps the schema stable for empty input.
    df = pd.DataFrame(tracks, columns=[0, 1, "speaker"])
    df['start'] = [segment.start for segment, _, _ in tracks]
    df['end'] = [segment.end for segment, _, _ in tracks]
    return df

def diarize(audio_file, n_speakers, device="cuda", hf_token=None, save=True):
    """Run pyannote speaker diarization on `audio_file`.

    Args:
        audio_file: path of the audio file to diarize.
        n_speakers: forwarded to the pipeline as `num_speakers`.
        device: torch device the pipeline is moved to.
        hf_token: Hugging Face access token; when None, the first line of
            `~/.huggingface/token` is used.
        save: when true, write the result to `diarization.csv` in the
            directory containing `audio_file`.

    Returns:
        The DataFrame produced by `diarization_to_df`.
    """
    if hf_token is None:
        with open(str(Path.home()/".huggingface/token"), "r") as f:
            # readline() keeps the line terminator; strip it so it does not
            # become part of the token used for authentication.
            hf_token = f.readline().strip()
    pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization@2.1",
        use_auth_token=hf_token
    ).to(torch.device(device))
    diarization = pipeline(audio_file, num_speakers=n_speakers)
    diarization_df = diarization_to_df(diarization)
    if save:
        diarization_df.to_csv(get_tmp_dir(audio_file)/"diarization.csv")
    return diarization_df

In [57]:
# n_speakers=None is forwarded to the pipeline as num_speakers=None —
# presumably this lets pyannote estimate the speaker count itself; confirm.
diarization = diarize(audio_file, n_speakers=None)

Model was trained with pyannote.audio 0.0.1, yours is 2.1.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.10.0+cu102, yours is 2.0.1. Bad things might happen unless you revert torch to 1.x.


In [58]:
diarization[:5]

Unnamed: 0,0,1,speaker,start,end
0,[ 00:00:00.497 --> 00:00:33.370],HG,SPEAKER_01,0.497812,33.370313
1,[ 00:00:34.534 --> 00:01:01.484],HH,SPEAKER_01,34.534688,61.484063
2,[ 00:01:02.479 --> 00:01:22.425],HI,SPEAKER_01,62.479688,82.425938
3,[ 00:01:23.050 --> 00:01:32.027],HJ,SPEAKER_01,83.050313,92.027812
4,[ 00:01:32.652 --> 00:01:40.465],HK,SPEAKER_01,92.652187,100.465312


In [59]:
#| export
def assign_speakers(diarization, transcript_whisperx):
    """Combine a diarization result with an aligned transcript.

    Thin wrapper that forwards both arguments, in this order, to whisperx's
    `assign_word_speakers` and returns its result unchanged.
    """
    return assign_word_speakers(diarization, transcript_whisperx)

In [60]:
# Work on a deep copy so that `transcript_whisperx` cannot be modified by the
# speaker-assignment step.
transcript_diarized = assign_speakers(diarization, deepcopy(transcript_whisperx))

In [61]:
print(transcript_diarized.keys())
print(len(transcript_diarized['word_segments']))
transcript_diarized['segments'][0]

dict_keys(['segments', 'word_segments'])
32747


{'start': 0.128,
 'end': 16.34,
 'text': ' The following is a conversation with Stephen Wolfram, a computer scientist, mathematician, and theoretical physicist, who is the founder and CEO of Wolfram Research, a company behind Mathematica, Wolfram Alpha, Wolfram Language, and the new Wolfram Physics Project.',
 'words': [{'word': 'The', 'start': 0.128, 'end': 0.208, 'score': 0.947},
  {'word': 'following',
   'start': 0.248,
   'end': 0.568,
   'score': 0.846,
   'speaker': 'SPEAKER_01'},
  {'word': 'is',
   'start': 0.608,
   'end': 0.668,
   'score': 0.785,
   'speaker': 'SPEAKER_01'},
  {'word': 'a',
   'start': 0.709,
   'end': 0.749,
   'score': 0.491,
   'speaker': 'SPEAKER_01'},
  {'word': 'conversation',
   'start': 0.789,
   'end': 1.369,
   'score': 0.917,
   'speaker': 'SPEAKER_01'},
  {'word': 'with',
   'start': 1.429,
   'end': 1.529,
   'score': 0.477,
   'speaker': 'SPEAKER_01'},
  {'word': 'Stephen',
   'start': 1.549,
   'end': 1.849,
   'score': 0.911,
   'speaker': '

In [62]:
[word['word'] for word in transcript_diarized['word_segments'][:10]]

['The',
 'following',
 'is',
 'a',
 'conversation',
 'with',
 'Stephen',
 'Wolfram,',
 'a',
 'computer']

# Postprocessing

In [63]:
#| export
# whisperx leaves some fields missing on segments and words;
# we fill each missing field with the most recent previous value.
def add_missing_field(obj, key, prev_value, missing_log=None):
    """Ensure `obj[key]` exists, falling back to `prev_value`.

    Args:
        obj: mapping to inspect; modified in place when `key` is absent.
        key: field name to look up.
        prev_value: value stored under `key` when it is missing.
        missing_log: optional mapping of key -> list; when truthy, `obj` is
            appended to `missing_log[key]` each time a value is filled in.

    Returns:
        `(obj, value)`, where `value` is `obj[key]` after the call, ready to
        be passed as `prev_value` for the next object.
    """
    # Field already present: nothing to fill, just report its value.
    if key in obj:
        return obj, obj[key]

    obj[key] = prev_value
    if missing_log:
        missing_log[key].append(obj)
    return obj, prev_value

def find_first_speaker(objects):
    """Return the `'speaker'` value of the first object that has one.

    Returns None when no object in `objects` contains a `'speaker'` key.
    """
    speakers = (item['speaker'] for item in objects if 'speaker' in item)
    return next(speakers, None)

def add_missing_segment_values(segments):
    """Fill missing segment speakers and missing word timestamps in place.

    Each segment without a `'speaker'` gets the previous segment's speaker
    (seeded with the first speaker found in `segments`). Each word in
    `segment['words']` without `'start'` / `'end'` gets the previous word's
    value (seeded with 0.0), carried across segment boundaries.

    Returns:
        `(segments, missing_log)`, where `missing_log` maps `"speaker"`,
        `"start"` and `"end"` to the list of objects that had that field
        filled in.
    """
    missing_log = {field: [] for field in ("speaker", "start", "end")}
    last_speaker = find_first_speaker(segments)
    last_start, last_end = 0., 0.

    for segment in segments:
        _, last_speaker = add_missing_field(segment, 'speaker', last_speaker, missing_log)
        for word in segment['words']:
            _, last_start = add_missing_field(word, 'start', last_start, missing_log)
            _, last_end = add_missing_field(word, 'end', last_end, missing_log)

    return segments, missing_log

def add_missing_word_values(words):
    """Fill missing `'speaker'`, `'start'` and `'end'` on each word in place.

    A missing field takes the value from the previous word; the speaker is
    seeded with the first speaker found in `words`, the timestamps with 0.0.

    Returns:
        The same `words` list, after filling.
    """
    last_speaker = find_first_speaker(words)
    last_start, last_end = 0., 0.

    for word in words:
        _, last_speaker = add_missing_field(word, 'speaker', last_speaker)
        _, last_start = add_missing_field(word, 'start', last_start)
        _, last_end = add_missing_field(word, 'end', last_end)

    return words

In [64]:
# add_missing_segment_values returns (segments, missing_log); the deep copy
# keeps `transcript_diarized` unmodified because the fill happens in place.
transcript_fixed, missing_values = add_missing_segment_values(deepcopy(transcript_diarized['segments']))

In [65]:
transcript_fixed[0].keys()

dict_keys(['start', 'end', 'text', 'words', 'speaker'])

In [66]:
print([word['word'] for word in missing_values['start']]) # missing values are all numeric
print([word['word'] for word in missing_values['speaker']])

['2019,', '$1.', '$10,', '$10', '2001', '30', '30', '1,000', '42,', '10', '150', '1900,', '1931', '1936,', '3785.', '50,000', '14.6', '30', '100', '100', '1915,', '100', '200', '200', '300', '10', '20', '1980,', '81,', '30', '30', '30', '100', '2002,', '1200', '1,200', '300', '30', '3.1415926,', '30,', '30,', '30', '30', '30.', '400', '30,', '1984,', '300', '30', '30', '30', '30,', '30', '$30,000', '30,', '$30,000', '10,000', '30', '30.', '30', '30.', '30', '2007', '30,', '400', '1988.', '6,000', '10', '6,000.', '6,000,', '6,000,', '10', '10', '10.', '8%', '57', '8%,', '8%', '89%', '6,000.', '10,', '10', '10', '10', '300', '12,', '60', '33', '10', '30', '20,', '30,', '50', '300', '50', '30', '1%,', '10,', '20%', '500']
[]


In [67]:
#| export 
def process_transcript(transcript):
    """Fill missing fields of a diarized transcript.

    Args:
        transcript: dict with `'segments'` and `'word_segments'` entries;
            both are modified in place.

    Returns:
        Dict with `'segments'` (the filled segment list) and `'words'`
        (the filled word list).
    """
    # add_missing_segment_values returns (segments, missing_log). Only the
    # segments belong in the transcript; storing the whole tuple would nest
    # the log inside 'segments'.
    segments, _missing_log = add_missing_segment_values(transcript['segments'])
    return {
        'segments': segments,
        'words': add_missing_word_values(transcript['word_segments'])
    }

In [68]:
transcript_diarized.keys()

dict_keys(['segments', 'word_segments'])

In [69]:
# process_transcript fills fields in place, so pass a deep copy to keep
# `transcript_diarized` unmodified.
transcript_processed = process_transcript(deepcopy(transcript_diarized))

In [70]:
print(transcript_processed.keys())

dict_keys(['segments', 'words'])


In [71]:
# Persist the processed transcript in the directory containing the audio file.
# ensure_ascii=False writes non-ASCII characters as-is, hence the explicit
# UTF-8 encoding on the file handle.
with open(get_tmp_dir(audio_file)/'transcript.json', "w", encoding='utf8') as f:
    json.dump(transcript_processed, f, ensure_ascii=False, indent=2)

In [72]:
#| export
def transcribe(audio_file, n_speakers, device="cuda", save=True):
    """Run the full pipeline: transcribe, align, diarize, assign speakers, fill gaps.

    Args:
        audio_file: path of the audio file to process.
        n_speakers: number of speakers, forwarded to `diarize`.
        device: device forwarded to the transcription, alignment and
            diarization steps.
        save: when true, write the processed transcript to `transcript.json`
            in the directory containing `audio_file`.

    Returns:
        The dict produced by `process_transcript`.

    Note:
        `save` only controls the final `transcript.json`. The intermediate
        steps are called with their own default `save=True` and therefore
        still write their files.
    """
    transcript_whisper = whisper_transcribe(audio_file=audio_file, device=device)
    transcript_aligned = whisperx_align(transcript_whisper, audio_file, device=device)
    diarization = diarize(audio_file, n_speakers, device)
    transcript_diarized = assign_speakers(diarization, transcript_aligned)
    transcript_processed = process_transcript(transcript_diarized)
    if save:
        # ensure_ascii=False emits raw non-ASCII characters, so the file must be
        # opened as UTF-8 explicitly instead of relying on the platform default.
        with open(get_tmp_dir(audio_file)/'transcript.json', "w", encoding="utf-8") as f:
            json.dump(transcript_processed, f, ensure_ascii=False, indent=2)
    return transcript_processed

In [73]:
#| hide
from nbdev import nbdev_export
nbdev_export()