In [1]:
#| default_exp transcribe

In [1]:
#| export
from whisperx import load_model, load_audio, load_align_model, align
from whisperx.diarize import assign_word_speakers
import torch
from pathlib import Path
import json
import pandas as pd
from pyannote.audio import Pipeline

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from copy import deepcopy

In [8]:
# Sample episode used while developing this notebook (path is relative to the kernel's working directory)
audio_file = "../data/podcast/people_i_admire_104_joy_of_maths(1)/tmp/audio.wav"

In [9]:
#| export
# whisperx config
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 16
# float16 needs a GPU; on CPU whisperx/faster-whisper are run with int8 instead.
compute_type = "float16" if device == "cuda" else "int8"
language = "en"

In [10]:
#| export
def get_tmp_dir(audio_file):
    """Return the directory that contains `audio_file`.

    Every intermediate and final artefact of the pipeline is written there.
    """
    audio_path = Path(audio_file)
    return audio_path.parent

# Transcription

In [11]:
#| export
def whisper_transcribe(audio_file, language=language, batch_size=batch_size, compute_type=compute_type, device=device, save=True):
    """Transcribe `audio_file` with Whisper large-v2 through whisperx.

    Args:
        audio_file: path to the audio file; its parent directory receives the JSON dump.
        language: language code used both to load the model and to transcribe.
            Defaults to the module-level `language`.
        batch_size: batch size passed to `model.transcribe`.
        compute_type: compute type passed to `load_model`.
        device: device the model is loaded on. Defaults to the module-level
            `device` (CUDA when available, else CPU) rather than assuming CUDA.
        save: when true, also write the segments to `transcript-whisper.json`.

    Returns:
        The ``segments`` list from `model.transcribe` (dicts with ``text``,
        ``start`` and ``end``).
    """
    model = load_model("large-v2", device, language=language, compute_type=compute_type)
    audio = load_audio(audio_file)
    transcript = model.transcribe(audio, language=language, batch_size=batch_size)['segments']
    if save:
        # ensure_ascii=False emits raw non-ASCII characters, so the file encoding
        # must be explicit instead of depending on the platform locale.
        with open(get_tmp_dir(audio_file)/"transcript-whisper.json", "w", encoding="utf-8") as f:
            json.dump(transcript, f, ensure_ascii=False, indent=2)
    return transcript

In [12]:
# Slow: loads Whisper large-v2 and transcribes the whole episode.
# With the default save=True this also writes transcript-whisper.json next to the audio.
transcript_whisper = whisper_transcribe(audio_file)

Lightning automatically upgraded your loaded checkpoint from v1.5.4 to v2.0.5. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file ../../../.cache/torch/whisperx-vad-segmentation.bin`


Model was trained with pyannote.audio 0.0.1, yours is 2.1.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.10.0+cu102, yours is 2.0.1. Bad things might happen unless you revert torch to 1.x.


In [13]:
# Raw Whisper output: segment text with segment-level start/end only (no word timings yet)
transcript_whisper[10:12]

[{'text': " It reminds me a little bit in our last conversation where you talked about the mathematics of vibration and how some combinations of sounds naturally sound good together. But it does seem like a little bit more of a stretch when you apply it to literature. Do you think that's fair? I think it's perhaps less",
  'start': 250.923,
  'end': 270.565},
 {'text': " on the surface. That's one thing. When you get into looking at how various kinds of poetry or literature are put together, it ceases to feel like a stretch. I'll give you an example of a book. There's a book called The Luminaries by Eleanor Catton. It won the Booker Prize in 2013. That book has a mathematical structure underneath it, which is that every chapter is half the length of the one before.",
  'start': 270.734,
  'end': 294.713}]

In [14]:
#| export
def whisperx_align(segments, audio_file, language=language, device=device, save=True):
    """Align Whisper segments against the audio to obtain word-level timings.

    Args:
        segments: segment list as returned by `whisper_transcribe`.
        audio_file: path to the audio; passed straight to whisperx `align` and
            used to locate the output directory.
        language: language code used to pick the alignment model.
        device: device the alignment model is loaded on and run with. Defaults
            to the module-level `device` rather than assuming CUDA.
        save: when true, also write the result to `transcript-whisperx.json`.

    Returns:
        The whisperx `align` result (in this notebook a dict with
        ``segments`` and ``word_segments``).
    """
    model, metadata = load_align_model(language_code=language, device=device)
    transcript_aligned = align(segments, model, metadata, audio_file, device=device)
    if save:
        # ensure_ascii=False emits raw non-ASCII characters; make the encoding explicit.
        with open(get_tmp_dir(audio_file)/"transcript-whisperx.json", "w", encoding="utf-8") as f:
            json.dump(transcript_aligned, f, ensure_ascii=False, indent=2)
    return transcript_aligned

In [15]:
# deepcopy in case `align` modifies the segments it is given, so the raw Whisper output above stays intact
transcript_whisperx = whisperx_align(deepcopy(transcript_whisper), audio_file)

In [16]:
# Alignment returns 'segments' (now with per-word timings) plus a flat 'word_segments' list.
# Segment indices no longer match the raw Whisper output above (index 10 now starts at ~68s,
# not ~251s); presumably alignment splits segments into smaller pieces.
print(transcript_whisperx.keys())
transcript_whisperx['segments'][10:12]

dict_keys(['segments', 'word_segments'])


[{'start': 68.125,
  'end': 75.05,
  'text': ' So, Sarah, you were on the show back in 2021, and that was a conversation that really sticks with me.',
  'words': [{'word': 'So,', 'start': 68.125, 'end': 68.245, 'score': 0.84},
   {'word': 'Sarah,', 'start': 68.265, 'end': 68.465, 'score': 0.65},
   {'word': 'you', 'start': 68.485, 'end': 68.565, 'score': 0.998},
   {'word': 'were', 'start': 68.605, 'end': 68.765, 'score': 0.798},
   {'word': 'on', 'start': 68.886, 'end': 68.966, 'score': 0.835},
   {'word': 'the', 'start': 68.986, 'end': 69.086, 'score': 0.717},
   {'word': 'show', 'start': 69.126, 'end': 69.406, 'score': 0.79},
   {'word': 'back', 'start': 69.706, 'end': 69.966, 'score': 0.983},
   {'word': 'in', 'start': 70.046, 'end': 70.146, 'score': 0.806},
   {'word': '2021,'},
   {'word': 'and', 'start': 70.707, 'end': 71.467, 'score': 0.65},
   {'word': 'that', 'start': 72.208, 'end': 72.328, 'score': 0.912},
   {'word': 'was', 'start': 72.348, 'end': 72.448, 'score': 0.777},
 

In [17]:
#| export
def diarization_to_df(diarization):
    """Flatten a pyannote diarization result into a DataFrame.

    One row per item of ``diarization.itertracks(yield_label=True)``. Column 0
    holds the segment object, column 1 the track name and ``speaker`` the label.
    Float ``start``/``end`` columns are taken from the segment.

    The columns are declared explicitly, so an empty diarization gives an empty
    frame with the same columns instead of raising ``KeyError`` on ``df[0]``.
    """
    tracks = list(diarization.itertracks(yield_label=True))
    df = pd.DataFrame(tracks, columns=[0, 1, 2])
    df['start'] = [segment.start for segment in df[0]]
    df['end'] = [segment.end for segment in df[0]]
    return df.rename(columns={2: "speaker"})

def diarize(audio_file, n_speakers, device=device, hf_token=None, save=True):
    """Run pyannote speaker diarization (speaker-diarization@2.1) on `audio_file`.

    Args:
        audio_file: path to the audio; its parent directory receives `diarization.csv`.
        n_speakers: forwarded to the pipeline as ``num_speakers``. The notebook
            passes None, presumably letting pyannote estimate the count.
        device: torch device string the pipeline is moved to. Defaults to the
            module-level `device` rather than assuming CUDA.
        hf_token: Hugging Face access token for the pyannote model. When None,
            the first line of ``~/.huggingface/token`` is used.
        save: when true, also write the DataFrame to `diarization.csv`.

    Returns:
        DataFrame from `diarization_to_df` (``speaker``, ``start``, ``end`` columns).
    """
    if hf_token is None:
        with open(Path.home()/".huggingface/token", "r") as f:
            # strip(): a trailing newline in the file would otherwise become part of the token
            hf_token = f.readline().strip()
    pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization@2.1",
        use_auth_token=hf_token
    ).to(torch.device(device))
    diarization = pipeline(audio_file, num_speakers=n_speakers)
    diarization_df = diarization_to_df(diarization)
    if save:
        diarization_df.to_csv(get_tmp_dir(audio_file)/"diarization.csv")
    return diarization_df

In [18]:
# n_speakers=None: no fixed speaker count (presumably pyannote estimates it); also writes diarization.csv
diarization = diarize(audio_file, n_speakers=None)

Model was trained with pyannote.audio 0.0.1, yours is 2.1.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.10.0+cu102, yours is 2.0.1. Bad things might happen unless you revert torch to 1.x.


In [19]:
# First diarization turns: column 0 is the pyannote segment, 1 the track name, then speaker/start/end
diarization[:5]

Unnamed: 0,0,1,speaker,start,end
0,[ 00:00:05.644 --> 00:00:17.490],G,SPEAKER_01,5.644687,17.490938
1,[ 00:00:18.132 --> 00:00:28.881],BO,SPEAKER_03,18.132188,28.881563
2,[ 00:00:31.446 --> 00:00:36.002],A,SPEAKER_00,31.446562,36.002813
3,[ 00:00:37.791 --> 00:01:00.505],H,SPEAKER_01,37.791563,60.505313
4,[ 00:01:08.065 --> 00:01:20.063],I,SPEAKER_01,68.065313,80.063438


In [20]:
#| export
def assign_speakers(diarization, transcript_whisperx):
    """Attach diarized speaker labels to an aligned whisperx transcript.

    Thin wrapper around whisperx's `assign_word_speakers`. `diarization` is the
    DataFrame produced by `diarize`. The wrapped function's result is returned
    unchanged.
    """
    return assign_word_speakers(diarization, transcript_whisperx)

In [21]:
# deepcopy so `transcript_whisperx` stays free of speaker labels, in case the speakers are written into its input
transcript_diarized = assign_speakers(diarization, deepcopy(transcript_whisperx))

In [22]:
# Same structure as before, but each timed word now carries a 'speaker' label
print(transcript_diarized.keys())
print(len(transcript_diarized['word_segments']))
transcript_diarized['segments'][0]

dict_keys(['segments', 'word_segments'])
8956


{'start': 5.688,
 'end': 13.491,
 'text': ' My guest today, Sarah Hart, is the Gresham Professor of Geometry, the first woman to hold that position in its 400-year history.',
 'words': [{'word': 'My',
   'start': 5.688,
   'end': 5.808,
   'score': 0.762,
   'speaker': 'SPEAKER_01'},
  {'word': 'guest',
   'start': 5.828,
   'end': 6.008,
   'score': 0.524,
   'speaker': 'SPEAKER_01'},
  {'word': 'today,',
   'start': 6.028,
   'end': 6.348,
   'score': 0.719,
   'speaker': 'SPEAKER_01'},
  {'word': 'Sarah',
   'start': 6.488,
   'end': 6.788,
   'score': 0.707,
   'speaker': 'SPEAKER_01'},
  {'word': 'Hart,',
   'start': 6.829,
   'end': 7.069,
   'score': 0.814,
   'speaker': 'SPEAKER_01'},
  {'word': 'is',
   'start': 7.349,
   'end': 7.449,
   'score': 0.641,
   'speaker': 'SPEAKER_01'},
  {'word': 'the',
   'start': 7.469,
   'end': 7.549,
   'score': 0.966,
   'speaker': 'SPEAKER_01'},
  {'word': 'Gresham',
   'start': 7.589,
   'end': 7.929,
   'score': 0.721,
   'speaker': 'SPE

In [23]:
# 'word_segments' is the flat, transcript-wide word list that process_transcript works on
[word['word'] for word in transcript_diarized['word_segments'][:10]]

['My',
 'guest',
 'today,',
 'Sarah',
 'Hart,',
 'is',
 'the',
 'Gresham',
 'Professor',
 'of']

# Postprocessing

In [24]:
#| export
# whisperx leaves some fields missing (below, the words lacking `start` are all
# numerals); we fill them in by carrying forward the previous value.
def add_missing_field(obj, key, prev_value, missing_log=None):
    """Fill ``obj[key]`` with `prev_value` when absent and return the value to carry forward.

    Args:
        obj: dict patched in place.
        key: field that must be present.
        prev_value: value used when `key` is missing.
        missing_log: optional dict mapping field name -> list. Every patched
            `obj` is appended under `key`, and the list is created on demand.
            An empty dict is still used as a log (only None disables logging).

    Returns:
        ``(obj, value)``, where `value` is ``obj[key]`` after patching. Callers
        use it as the next `prev_value`.
    """
    if key not in obj:
        obj[key] = prev_value
        if missing_log is not None:
            missing_log.setdefault(key, []).append(obj)
    return obj, obj[key]

def find_first_speaker(objects):
    """Return the 'speaker' value of the first object that has one, or None if none does."""
    speakers = (candidate['speaker'] for candidate in objects if 'speaker' in candidate)
    return next(speakers, None)

def add_missing_segment_values(segments):
    """Fill missing segment speakers and word start/end times by carrying values forward.

    A segment without 'speaker' inherits the previous segment's speaker. The
    very first one falls back to the first speaker found anywhere in
    `segments`. A word without 'start'/'end' inherits the previous word's value,
    starting from 0.0. Dicts are patched in place.

    Returns:
        ``(segments, missing_log)``, where `missing_log` maps 'speaker', 'start'
        and 'end' to the list of objects that had that field filled in.
    """
    missing_log = {"speaker": [], "start": [], "end": []}
    last = {'speaker': find_first_speaker(segments), 'start': 0., 'end': 0.}
    for segment in segments:
        segment, last['speaker'] = add_missing_field(segment, 'speaker', last['speaker'], missing_log)
        for word in segment['words']:
            for field in ('start', 'end'):
                word, last[field] = add_missing_field(word, field, last[field], missing_log)
    return segments, missing_log

def add_missing_word_values(words):
    """Fill missing 'speaker', 'start' and 'end' on each word by carrying the previous value forward.

    The initial speaker is the first one found in `words`; the initial times are
    0.0. Words are patched in place and the same list is returned.
    """
    last = {'speaker': find_first_speaker(words), 'start': 0., 'end': 0.}
    for word in words:
        for field in ('speaker', 'start', 'end'):
            word, last[field] = add_missing_field(word, field, last[field])
    return words

In [25]:
# Patch a copy of the segments and keep, per field, the list of objects that had to be filled in
transcript_fixed, missing_values = add_missing_segment_values(deepcopy(transcript_diarized['segments']))

In [26]:
# Keys present on a patched segment
transcript_fixed[0].keys()

dict_keys(['start', 'end', 'text', 'words', 'speaker'])

In [27]:
# Which objects had to be patched: words missing a start time, and segments missing a speaker
print([word['word'] for word in missing_values['start']]) # missing values are all numeric
print([word['word'] for word in missing_values['speaker']])

['2021,', '17', '5', '7', '5.', '5,', '7,', '17,', '17?', '16?', '12', '17', '2013.', '14', '10', '10', '10', '10', '10', '10', '10', '10,', '14', '100', '26', '61.', '273.', '10', '16', '32', '24', '2', '2', '10,', '2', '10', '1,024', '50', '42.', '42.', '2,', '2', '2', '2,', '8.', '2', '2,', '4.', '8.', '4.', '10', '10', '10', '15.', '20', '20', '1597', '2021.', '3,000', '12', '49', '96.']
[]


### Ordering speaker numbers

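The raw diarization labels are not in order of appearance (e.g. the first voices heard are `SPEAKER_01`, then `SPEAKER_03`, then `SPEAKER_00`). This was causing confusion, to myself as well as to the LLMs, so here we relabel the speakers so that their numbers appear sequentially, in order of first appearance.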

In [43]:
# Sanity check: the diarized transcript still exposes both 'segments' and 'word_segments'
transcript_diarized.keys()

dict_keys(['segments', 'word_segments'])

In [57]:
# Explore the relabelling on a copy, so `transcript_diarized` keeps whisperx's original labels
transcript_relabelled = deepcopy(transcript_diarized)
segments = transcript_relabelled['segments']
words = transcript_relabelled['word_segments']

# dict.fromkeys is an ordered set: segment speakers first, then any speaker seen only on words.
# Some words (e.g. the untimed numerals above) may carry no speaker, so they are skipped.
speaker_order = dict.fromkeys(
    [seg['speaker'] for seg in segments if 'speaker' in seg]
    + [word['speaker'] for word in words if 'speaker' in word]
)
# Zero-padded like pyannote's labels: SPEAKER_10, not SPEAKER_010, once there are 10+ speakers
renamed_speakers = {speaker: f"SPEAKER_{i:02d}" for i, speaker in enumerate(speaker_order)}

for item in segments + words:
    if 'speaker' in item:
        item['speaker'] = renamed_speakers[item['speaker']]

for speech in segments[:10]:
    print(speech['speaker'])

SPEAKER_00
SPEAKER_00
SPEAKER_01
SPEAKER_01
SPEAKER_01
SPEAKER_01
SPEAKER_02
SPEAKER_00
SPEAKER_00
SPEAKER_00


In [70]:
#| export 
def order_speakers(transcript_words):
    """Relabel speakers so that their numbers follow the order of first appearance.

    The first speaker heard becomes ``SPEAKER_00``, the next new one
    ``SPEAKER_01``, and so on. Labels are zero-padded to two digits, matching
    the diarization labels (``SPEAKER_10`` rather than ``SPEAKER_010`` once
    there are ten or more speakers).

    Args:
        transcript_words: list of word dicts that all have a 'speaker' key
            (e.g. the output of `add_missing_word_values`). Modified in place.

    Returns:
        The same list, with every 'speaker' value renamed.
    """
    # dict keys keep insertion order, so this is an ordered set of speakers built in O(n)
    first_seen = dict.fromkeys(word['speaker'] for word in transcript_words)
    rename_mapping = {speaker: f"SPEAKER_{i:02d}" for i, speaker in enumerate(first_seen)}

    for word in transcript_words:
        word['speaker'] = rename_mapping[word['speaker']]

    return transcript_words

In [71]:
# Work on a copy: both helpers patch the word dicts in place, which would silently alter transcript_diarized
processed_transcript = order_speakers(add_missing_word_values(deepcopy(transcript_diarized['word_segments'])))

In [72]:
#| export 
def process_transcript(transcript):
    """Turn a diarized whisperx transcript into the final flat word list.

    Takes ``transcript['word_segments']``, fills any missing speaker/start/end
    values, then renumbers speakers in order of first appearance. The word dicts
    are modified in place, and the resulting list is returned.
    """
    words = transcript['word_segments']
    words = add_missing_word_values(words)
    return order_speakers(words)

In [73]:
# Re-check (same as above): still only 'segments' and 'word_segments' before processing
transcript_diarized.keys()

dict_keys(['segments', 'word_segments'])

In [74]:
# process_transcript patches the word dicts in place, so hand it a copy
transcript_processed = process_transcript(deepcopy(transcript_diarized))

In [75]:
# First processed word: its speaker label is now SPEAKER_00 (the first voice heard)
print(transcript_processed[0])

{'word': 'My', 'start': 5.688, 'end': 5.808, 'score': 0.762, 'speaker': 'SPEAKER_00'}


In [76]:
# Save the final word list; explicit utf8 because ensure_ascii=False writes raw non-ASCII characters
with open(get_tmp_dir(audio_file)/'transcript.json', "w", encoding='utf8') as f:
    json.dump(transcript_processed, f, ensure_ascii=False, indent=2)

In [77]:
#| export
def transcribe(audio_file, n_speakers=None, device=device, save=True):
    """Run the whole pipeline: Whisper transcription, whisperx alignment, pyannote diarization, post-processing.

    Args:
        audio_file: path to the audio file; outputs go to its parent directory.
        n_speakers: forwarded to `diarize`. None (the default) leaves the
            speaker count unset.
        device: device used by every stage. Defaults to the module-level
            `device` rather than assuming CUDA.
        save: when true, also write the final word list to `transcript.json`.
            The intermediate stages still save their own outputs with their
            default ``save=True``.

    Returns:
        The processed word list from `process_transcript`.
    """
    transcript_whisper = whisper_transcribe(audio_file=audio_file, device=device)
    transcript_aligned = whisperx_align(transcript_whisper, audio_file, device=device)
    diarization = diarize(audio_file, n_speakers, device)
    transcript_diarized = assign_speakers(diarization, transcript_aligned)
    transcript_processed = process_transcript(transcript_diarized)
    if save:
        # ensure_ascii=False emits raw non-ASCII characters; make the encoding explicit.
        with open(get_tmp_dir(audio_file)/'transcript.json', "w", encoding="utf-8") as f:
            json.dump(transcript_processed, f, ensure_ascii=False, indent=2)
    return transcript_processed

In [79]:
#| hide
from nbdev import nbdev_export
# Write the `#| export` cells out to the `transcribe` module named by `#| default_exp`
nbdev_export()