In [1]:
# Install Whisper from GitHub repository
!pip install dtw-python

import io
import os
import numpy as np
import pandas as pd
import torch
import urllib
import tarfile
import torchaudio
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import matplotlib.ticker as ticker
from scipy.io import wavfile
from tqdm.notebook import tqdm
from dtw import dtw
from whisper.tokenizer import get_tokenizer
from scipy.ndimage import median_filter
import ipywidgets as widgets
from IPython.display import display, HTML


Importing the dtw module. When using in academic works please cite:
  T. Giorgino. Computing and Visualizing Dynamic Time Warping Alignments in R: The dtw Package.
  J. Stat. Soft., doi:10.18637/jss.v031.i07.



In [2]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
pd.options.display.max_rows = 100
pd.options.display.max_colwidth = 1000

In [3]:
import ipywidgets as widgets

languages = {"af_za": "Afrikaans", "am_et": "Amharic", "ar_eg": "Arabic", "as_in": "Assamese", "az_az": "Azerbaijani", "be_by": "Belarusian", "bg_bg": "Bulgarian", "bn_in": "Bengali", "bs_ba": "Bosnian", "ca_es": "Catalan", "cmn_hans_cn": "Chinese", "cs_cz": "Czech", "cy_gb": "Welsh", "da_dk": "Danish", "de_de": "German", "el_gr": "Greek", "en_us": "English", "es_419": "Spanish", "et_ee": "Estonian", "fa_ir": "Persian", "fi_fi": "Finnish", "fil_ph": "Tagalog", "fr_fr": "French", "gl_es": "Galician", "gu_in": "Gujarati", "ha_ng": "Hausa", "he_il": "Hebrew", "hi_in": "Hindi", "hr_hr": "Croatian", "hu_hu": "Hungarian", "hy_am": "Armenian", "id_id": "Indonesian", "is_is": "Icelandic", "it_it": "Italian", "ja_jp": "Japanese", "jv_id": "Javanese", "ka_ge": "Georgian", "kk_kz": "Kazakh", "km_kh": "Khmer", "kn_in": "Kannada", "ko_kr": "Korean", "lb_lu": "Luxembourgish", "ln_cd": "Lingala", "lo_la": "Lao", "lt_lt": "Lithuanian", "lv_lv": "Latvian", "mi_nz": "Maori", "mk_mk": "Macedonian", "ml_in": "Malayalam", "mn_mn": "Mongolian", "mr_in": "Marathi", "ms_my": "Malay", "mt_mt": "Maltese", "my_mm": "Myanmar", "nb_no": "Norwegian", "ne_np": "Nepali", "nl_nl": "Dutch", "oc_fr": "Occitan", "pa_in": "Punjabi", "pl_pl": "Polish", "ps_af": "Pashto", "pt_br": "Portuguese", "ro_ro": "Romanian", "ru_ru": "Russian", "sd_in": "Sindhi", "sk_sk": "Slovak", "sl_si": "Slovenian", "sn_zw": "Shona", "so_so": "Somali", "sr_rs": "Serbian", "sv_se": "Swedish", "sw_ke": "Swahili", "ta_in": "Tamil", "te_in": "Telugu", "tg_tj": "Tajik", "th_th": "Thai", "tr_tr": "Turkish", "uk_ua": "Ukrainian", "ur_pk": "Urdu", "uz_uz": "Uzbek", "vi_vn": "Vietnamese", "yo_ng": "Yoruba"}
# selection = widgets.Dropdown(
#     options=[("Select language", None), ("----------", None)] + sorted([(f"{v} ({k})", k) for k, v in languages.items()]),
#     value="en_us",
#     description='Language:',
#     disabled=False,
# )

lang = "en_us"
language = "English"

# assert lang is not None, "Please select a language"
print(f"Selected language: {language} ({lang})")

Selected language: English (en_us)


In [4]:
def download(url: str, target_path: str):
    with urllib.request.urlopen(url) as source, open(target_path, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))


class Fleurs(torch.utils.data.Dataset):
    """
    A simple class to wrap Fleurs and subsample a portion of the dataset as needed.
    """
    def __init__(self, lang, split="test", subsample_rate=1, device=DEVICE):
        url = f"https://storage.googleapis.com/xtreme_translations/FLEURS102/{lang}.tar.gz"
        tar_path = os.path.expanduser(f"~/.cache/fleurs/{lang}.tgz")
        os.makedirs(os.path.dirname(tar_path), exist_ok=True)

        if not os.path.exists(tar_path):
            download(url, tar_path)

        all_audio = {}
        with tarfile.open(tar_path, "r:gz") as tar:
            for member in tar.getmembers():
                name = member.name
                if name.endswith(f"{split}.tsv"):
                    labels = pd.read_table(tar.extractfile(member), names=("id", "file_name", "raw_transcription", "transcription", "_", "num_samples", "gender"))

                if f"/{split}/" in name and name.endswith(".wav"):
                    audio_bytes = tar.extractfile(member).read()
                    all_audio[os.path.basename(name)] = wavfile.read(io.BytesIO(audio_bytes))[1]                    

        self.labels = labels.to_dict("records")[::subsample_rate]
        self.all_audio = all_audio
        self.device = device

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, item):
        record = self.labels[item]
        audio = torch.from_numpy(self.all_audio[record["file_name"]].copy())
        text = record["transcription"]
        
        return (audio, text)

In [5]:
dataset = Fleurs(lang, subsample_rate=100)  # subsample 10% of the dataset for a quick demo

In [14]:
from whisper import whisper
whisper.model.MultiHeadAttention.use_sdpa = False
model = whisper.load_model("base")

print(
    f"Model is {'multilingual' if model.is_multilingual else 'English-only'} "
    f"and has {sum(np.prod(p.shape) for p in model.parameters()):,} parameters."
)

options = dict(language=language, beam_size=5, best_of=5)
transcribe_options = dict(task="transcribe", **options)
# translate_options = dict(task="translate", **options)

# references = []
transcriptions = []
# translations = []

for audio, text in tqdm(dataset):
    transcription = model.transcribe(audio, **transcribe_options)["text"]
    # translation = model.transcribe(audio, **translate_options)["text"]
    
    transcriptions.append(transcription)
    # translations.append(translation)
    # references.append(text)

  checkpoint = torch.load(fp, map_location=device)


Model is multilingual and has 71,825,920 parameters.


  0%|          | 0/7 [00:00<?, ?it/s]



In [16]:
data = pd.DataFrame(dict(transcription=transcriptions))
data

Unnamed: 0,transcription
0,"Unfortunately, stunning traffic flow is difficult because driver behavior cannot be predicted with 100% certainty."
1,"Dr. Mala Bala Subramanian 29 was found in Blue Ash, Ohio, a suburb approximately 15 miles north of Cincinnati, lying on the ground beside the road in a t-shirt and underwear in an apparently heavily medicated state."
2,"Goats seemed to have been first domesticated roughly 10,000 years ago in the Zagros Mountains of Iran."
3,The 25 Dunlap Broadside's still known to exist are the oldest surviving copies of the document. The original handwritten copy has not survived.
4,"Elements like calcium and potassium are considered metals, of course there are also metals like silver and gold."
5,"The pH levels level is indicated by the amount of hydrogen, the age and pH, ions in the tested chemical."
6,"Aristotle, philosopher, theorized that everything is made up of a mixture of one or more four elements. They were earth, water, air, and fire."


In [17]:
import string
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import matplotlib.ticker as ticker

from IPython.display import display, HTML
from whisper.tokenizer import get_tokenizer
from dtw import dtw
from scipy.ndimage import median_filter

%matplotlib inline
%config InlineBackend.figure_format = "retina"

In [18]:
AUDIO_SAMPLES_PER_TOKEN = whisper.audio.HOP_LENGTH * 2
AUDIO_TIME_PER_TOKEN = AUDIO_SAMPLES_PER_TOKEN / whisper.audio.SAMPLE_RATE

medfilt_width = 7
qk_scale = 1.0

tokenizer = get_tokenizer(model.is_multilingual, language=languages[lang])

In [19]:
def split_tokens_on_unicode(tokens: torch.Tensor):
    words = []
    word_tokens = []
    current_tokens = []
    
    for token in tokens.tolist():
        current_tokens.append(token)
        decoded = tokenizer.decode_with_timestamps(current_tokens)
        if "\ufffd" not in decoded:
            words.append(decoded)
            word_tokens.append(current_tokens)
            current_tokens = []
    
    return words, word_tokens

In [20]:
def split_tokens_on_spaces(tokens: torch.Tensor):
    subwords, subword_tokens_list = split_tokens_on_unicode(tokens)
    words = []
    word_tokens = []
    
    for subword, subword_tokens in zip(subwords, subword_tokens_list):
        special = subword_tokens[0] >= tokenizer.eot
        with_space = subword.startswith(" ")
        punctuation = subword.strip() in string.punctuation
        if special or with_space or punctuation:
            words.append(subword)
            word_tokens.append(subword_tokens)
        else:
            words[-1] = words[-1] + subword
            word_tokens[-1].extend(subword_tokens)
    
    return words, word_tokens

In [21]:
# install hooks on the cross attention layers to retrieve the attention weights
QKs = [None] * model.dims.n_text_layer

# Define the hook function explicitly
def save_attention_weights(module, input, output, index):
    QKs[index] = output[-1]  # Save the attention weights for each layer

# Register the forward hooks explicitly
for i, block in enumerate(model.decoder.blocks):
    block.cross_attn.register_forward_hook(lambda module, input, output, i=i: save_attention_weights(module, input, output, i))

In [40]:
# for the first 10 examples in the dataset
sampliing = 16000
AUDIO_TIME_PER_TOKEN = AUDIO_SAMPLES_PER_TOKEN / whisper.audio.SAMPLE_RATE
audio_file = "buffers/buffer_0.wav"
audio_full = torch.tensor(whisper.load_audio(audio_file))
transcription = "When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts."


def get_audio_index_at_time(time: float):
    return int(time * whisper.audio.SAMPLE_RATE)

# for (audio, label), transcription in zip(dataset, transcriptions):
print(transcription)

audio_len = len(audio_full) / whisper.audio.SAMPLE_RATE
chunk_len_s = 2

# loop through the audio in chunk_len_s chunks
for t in range(chunk_len_s, int(audio_len) + chunk_len_s, chunk_len_s):
    # simulate the rest of the audio being silence at time 5s
    duration = get_audio_index_at_time(t)
    audio = torch.cat([audio_full[:duration], torch.zeros(max(0, len(audio_full) - duration))])

    duration = len(audio)
    mel = whisper.log_mel_spectrogram(whisper.pad_or_trim(audio))
    tokens = torch.tensor(
        [
            *tokenizer.sot_sequence,
            tokenizer.timestamp_begin,
        ] + tokenizer.encode(transcription) + [
            tokenizer.timestamp_begin + duration // AUDIO_SAMPLES_PER_TOKEN,
            tokenizer.eot,
        ]
    )
    with torch.no_grad():
        logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))

    print(mel.shape, tokens.shape)
    weights = torch.cat(QKs)  # layers * heads * tokens * frames    
    weights = weights[:, :, :, : duration // AUDIO_SAMPLES_PER_TOKEN].cpu()
    weights = median_filter(weights, (1, 1, 1, medfilt_width))
    weights = torch.tensor(weights * qk_scale).softmax(dim=-1)

    w = weights / weights.norm(dim=-2, keepdim=True)
    matrix = w[-6:].mean(axis=(0, 1))

    alignment = dtw(-matrix.double().numpy())

    jumps = np.pad(np.diff(alignment.index1s), (1, 0), constant_values=1).astype(bool)
    jump_times = alignment.index2s[jumps] * AUDIO_TIME_PER_TOKEN
    words, word_tokens = split_tokens_on_spaces(tokens) 

    # display the word-level timestamps in a table
    word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
    begin_times = jump_times[word_boundaries[:-1]]
    end_times = jump_times[word_boundaries[1:]]

    # we only want to display words up to the current time

    stop_ind = np.argmax(end_times[np.where(end_times < t)[0]])

    data = [
        dict(word=word, begin=begin, end=end)
        for word, begin, end in zip(words[:stop_ind], begin_times[:stop_ind], end_times[:stop_ind])
        if not word.startswith("<|") and word.strip() not in ".,!?、。" 
    ]

    display(pd.DataFrame(data))
    display(HTML("<hr>"))


When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts.
2
torch.Size([80, 3000]) torch.Size([27])


Unnamed: 0,word,begin,end
0,Mr,1.08,1.34
1,and,1.34,1.5
2,Mrs,1.5,1.6


4
torch.Size([80, 3000]) torch.Size([27])


Unnamed: 0,word,begin,end
0,Mr,1.1,1.36
1,and,1.38,1.56
2,Mrs,1.56,1.6
3,Dursley,1.62,2.22
4,woke,2.22,2.68
5,up,2.68,2.84
6,on,2.84,2.9
7,the,2.9,3.22
8,dull,3.22,3.26


6
torch.Size([80, 3000]) torch.Size([27])


Unnamed: 0,word,begin,end
0,Mr,1.1,1.34
1,and,1.36,1.52
2,Mrs,1.52,1.6
3,Dursley,1.62,2.22
4,woke,2.22,2.66
5,up,2.66,2.84
6,on,2.84,2.9
7,the,2.9,3.22
8,dull,3.22,3.28
9,gray,3.64,4.16


8
torch.Size([80, 3000]) torch.Size([27])


Unnamed: 0,word,begin,end
0,Mr,1.1,1.34
1,and,1.36,1.52
2,Mrs,1.52,1.6
3,Dursley,1.62,2.22
4,woke,2.22,2.66
5,up,2.66,2.84
6,on,2.84,2.9
7,the,2.9,3.22
8,dull,3.22,3.28
9,gray,3.64,4.16
