In [1]:
import datasets
import librosa
import torch
import torchaudio
import transformers

In [2]:
# One checkpoint name shared by the tokenizer, the feature extractor and the model.
model_str = "openai/whisper-tiny"

# Tokenizer set up for English transcription.
tokenizer = transformers.WhisperTokenizer.from_pretrained(model_str, language="en", task="transcribe")

# Produces the `input_features` tensors that are later passed to `model.generate`.
feature_extractor = transformers.WhisperFeatureExtractor.from_pretrained(model_str)

model = transformers.WhisperForConditionalGeneration.from_pretrained(model_str)

In [3]:
# Decode the tokenizer's special prefix tokens so they can be inspected as text.
prefix_token_strings = tokenizer.batch_decode(tokenizer.prefix_tokens)
prefix_token_strings

['<|startoftranscript|>', '<|en|>', '<|transcribe|>', '<|notimestamps|>']

## Load some data

You might need to visit the dataset page on the Hugging Face Hub and agree to its terms of use.
You also need to log in with the Hugging Face CLI (`huggingface-cli login`) to authenticate before loading the dataset.

In [4]:
# Open the English training split of Common Voice 13 in streaming mode,
# authenticating with the locally stored Hugging Face credentials.
# NOTE(review): recent `datasets` releases deprecate `use_auth_token` in favour
# of `token` - confirm against the installed version before switching.
cv_13 = datasets.load_dataset(
    "mozilla-foundation/common_voice_13_0",
    "en",
    split="train",
    streaming=True,
    use_auth_token=True,
)

In [5]:
# Take examples from the stream until `batch_size` of them are collected,
# then stop iterating.
batch_size = 8
batch = []
for n_collected, sample in enumerate(cv_13, start=1):
    batch.append(sample)
    if n_collected >= batch_size:
        break

Reading metadata...: 1013968it [00:25, 40092.05it/s]
To support 'mp3' decoding with `torchaudio>=0.12.0`, please install `ffmpeg4` system package. On Google Colab you can run:

	!add-apt-repository -y ppa:jonathonf/ffmpeg-4 && apt update && apt install -y ffmpeg

and restart your runtime. Alternatively, you can downgrade `torchaudio`:

	pip install "torchaudio<0.12"`.

Otherwise 'mp3' files will be decoded with `librosa`.


In [6]:
# Pick the first collected example.
sample = batch[0]

In [28]:
# Resample the clip from its native rate to 16 kHz.
source_audio = sample["audio"]
audio_signal = torchaudio.functional.resample(
    waveform=torch.tensor(source_audio["array"]),
    orig_freq=source_audio["sampling_rate"],
    new_freq=16000,
)

# Reference transcript, kept for comparison with the generated text.
label_str = sample["sentence"]

audio_signal.shape

torch.Size([104064])

In [29]:
# Run the feature extractor on the 16 kHz signal and keep only the `input_features` tensor.
extractor_output = feature_extractor(audio_signal, sampling_rate=16000, return_tensors="pt")
features = extractor_output["input_features"]

In [61]:
# Text the decoder is forced to start with; generation has to continue from it.
# A longer alternative to try: "This is a device that has"
forced_prefix_str = "This device"

# Tokenize the prefix as a target sequence (special prefix tokens + text + EOS).
forced_prefix = tokenizer(text_target=forced_prefix_str, return_tensors="pt")["input_ids"]

# The trailing EOS has to be removed before the ids can serve as a decoder prompt.
# Verify that the last token really is EOS, so the slice below can never silently
# drop a genuine prefix token if the tokenizer behaviour changes.
assert (forced_prefix[:, -1] == tokenizer.eos_token_id).all()
forced_prefix = forced_prefix[:, :-1]

forced_prefix_decoded = tokenizer.decode(forced_prefix[0], skip_special_tokens=False)
print(f"{forced_prefix_decoded = } \n")

# Deterministic beam search seeded with the forced prefix; `[0]` selects the
# first (and only) returned sequence.
outputs = model.generate(
    inputs=features,
    decoder_input_ids=forced_prefix,
    max_length=100,
    num_beams=5,
    do_sample=False,
)[0]
outputs_str = tokenizer.decode(outputs, skip_special_tokens=True)

print(f"{label_str = }")
print(f"{forced_prefix_str = }")
print(f"{outputs_str = }")


forced_prefix_decoded = '<|startoftranscript|><|en|><|transcribe|><|notimestamps|>This device' 

label_str = 'This device has a cathode inside an anode wire cage.'
forced_prefix_str = 'This device'
outputs_str = 'This device has a cathode inside an anode wire cage.'


Apparently, the model worked well with the forced prefix.
It picked up on which part of the audio had probably been transcribed already and finished the remaining part.

## Batch version

In [62]:
# Trivial experiment: a batch made of two copies of the same example, each
# with the same forced prefix.
duplicated_features = features.repeat(2, 1, 1)
duplicated_prefix = forced_prefix.repeat(2, 1)

batch_outputs = model.generate(
    inputs=duplicated_features,
    decoder_input_ids=duplicated_prefix,
    max_length=100,
    num_beams=5,
    do_sample=False,
)

tokenizer.batch_decode(batch_outputs, skip_special_tokens=True)

['This device has a cathode inside an anode wire cage.',
 'This device has a cathode inside an anode wire cage.']

In [69]:
# Preview the decoder prompt with two pad tokens placed in front of the forced prefix.
leading_pad = torch.tensor([[tokenizer.pad_token_id] * 2])
torch.cat((leading_pad, forced_prefix), dim=-1)

tensor([[50257, 50257, 50258, 50259, 50359, 50363,  5723,  4302]])

In [81]:
# Does the model work with padding inside the forced decoder prompt?
#
# "Fluff" means the special tokens that open every prompt:
#     <|startoftranscript|><|en|><|transcribe|><|notimestamps|>
#
# 1) Definitely not if the padding comes BEFORE the fluff, i.e.:
#     <PAD><PAD><|startoftranscript|><|en|><|transcribe|><|notimestamps|> <FORCED PREFIX>
#
# 2) Maybe if the padding comes AFTER the fluff, i.e.:
#     <|startoftranscript|><|en|><|transcribe|><|notimestamps|><PAD><PAD> <FORCED PREFIX>
#    Even though this trivial experiment suggests the model can handle it during
#    .generate(), it seems like a good idea to also use such padding in training
#    if it is going to be present at generation time.
#
# NOTE(review): confirm that the installed transformers version actually forwards
# `decoder_attention_mask` to the Whisper decoder during generation; if it is
# ignored, this experiment only probes the pad *tokens*, not the masking.


def _generate_from_prompt(prompt_ids, prompt_mask):
    """Beam-search decode `features` starting from the given decoder prompt and its mask."""
    return model.generate(
        inputs=features,
        decoder_input_ids=prompt_ids,
        decoder_attention_mask=prompt_mask,
        max_length=100,
        num_beams=5,
        do_sample=False,
    )


num_pad_tokens = 2
# Number of leading special tokens, taken from the tokenizer instead of hard-coding it.
num_fluff_tokens = len(tokenizer.prefix_tokens)

# Pad ids and their (all-zero) mask share the integer dtype of the prompt, so the
# concatenated masks stay integer tensors like the ones the tokenizer produces
# (mixing in a float `torch.zeros` would promote the whole mask to float).
pad_ids = torch.full((1, num_pad_tokens), tokenizer.pad_token_id, dtype=forced_prefix.dtype)
pad_mask = torch.zeros_like(pad_ids)

fluff_ids = forced_prefix[:, :num_fluff_tokens]
text_ids = forced_prefix[:, num_fluff_tokens:]

batch_outputs_1 = _generate_from_prompt(
    torch.cat([pad_ids, forced_prefix], dim=-1),
    torch.cat([pad_mask, torch.ones_like(forced_prefix)], dim=-1),
)

print("PAD BEFORE FLUFF")
print(tokenizer.batch_decode(batch_outputs_1, skip_special_tokens=False))
print()


batch_outputs_2 = _generate_from_prompt(
    torch.cat([fluff_ids, pad_ids, text_ids], dim=-1),
    torch.cat([torch.ones_like(fluff_ids), pad_mask, torch.ones_like(text_ids)], dim=-1),
)

print("PAD BETWEEN FLUFF AND PREFIX")
print(tokenizer.batch_decode(batch_outputs_2, skip_special_tokens=False))


PAD BEFORE FLUFF
['<|endoftext|><|endoftext|><|startoftranscript|><|en|><|transcribe|><|notimestamps|>This device, is, is, is, is, is, is-is-is-is-is-is-is-is-is-is-is-is-is-is-is-is-is-is-is-is-is-is-is-is-is-is-is-is-is-is-is-is-is-is-is-is-is-is-is-is-is']

PAD BETWEEN FLUFF AND PREFIX
['<|startoftranscript|><|en|><|transcribe|><|notimestamps|><|endoftext|><|endoftext|>This device has a cathode inside an anode wire cage.<|endoftext|>']


In [10]:
# For simplicity, the batch is just the same example twice.
batch = [sample] * 2

In [103]:
# Resample every clip in the batch to 16 kHz.
resampled_signals = [
    torchaudio.functional.resample(
        torch.tensor(example["audio"]["array"]),
        example["audio"]["sampling_rate"],
        16000,
    )
    for example in batch
]

# Down-mix each signal to mono with librosa, which operates on NumPy arrays.
# A separate name is used for the tensors above so that `batch_audio_signals`
# only ever holds one kind of object.
batch_audio_signals = [librosa.to_mono(signal.numpy()) for signal in resampled_signals]

# Extract features clip by clip and stack them along the batch dimension.
batch_features = torch.cat([
    feature_extractor(signal, sampling_rate=16000, return_tensors="pt")["input_features"]
    for signal in batch_audio_signals
])

assert batch_features.ndim == 3
assert batch_features.shape[0] == len(batch)

# Reference transcripts, one per example.
batch_labels_str = [example["sentence"] for example in batch]

# One forced prefix per example.
# Another pair of prefixes to try: ["clotho caption : ", "audioset keywords : "]
batch_prefixes_str = ["This device", "Hello darkness my old"]

# The prefixes are later zipped with the labels and outputs; a length mismatch
# would be truncated silently there, so fail loudly here instead.
assert len(batch_prefixes_str) == len(batch)

In [104]:
# Display the reference transcripts of the batch.
batch_labels_str

['This device has a cathode inside an anode wire cage.',
 'This device has a cathode inside an anode wire cage.']

In [105]:
# Special-token "fluff": tokenize empty targets, one per prefix, then strip the
# trailing EOS (asserted to be present) from both the ids and the mask.
batch_fluff = tokenizer(
    text_target=[""] * len(batch_prefixes_str),
    return_tensors="pt",
    padding=True,
)
assert (batch_fluff["input_ids"][:, -1] == tokenizer.eos_token_id).all()
batch_fluff_input_ids = batch_fluff["input_ids"][:, :-1]
batch_fluff_attention_mask = batch_fluff["attention_mask"][:, :-1]

# Tokenize the textual prefixes without special tokens, padded on the LEFT so
# that the padding ends up between the fluff and the prefix text.
# The tokenizer's padding side is restored in `finally`, so a failing
# tokenization cannot leave the shared tokenizer in left-padding mode.
orig_padding_side = tokenizer.padding_side
tokenizer.padding_side = "left"
try:
    batch_prefixes = tokenizer(
        text_target=batch_prefixes_str,
        return_tensors="pt",
        add_special_tokens=False,
        padding=True,
    )
finally:
    tokenizer.padding_side = orig_padding_side

# Full decoder prompts: fluff followed by the (left-padded) prefix text.
batch_prefixes_input_ids = torch.cat([batch_fluff_input_ids, batch_prefixes["input_ids"]], dim=-1)
batch_prefixes_attention_mask = torch.cat([batch_fluff_attention_mask, batch_prefixes["attention_mask"]], dim=-1)


print("FORCED PREFIXES")
for decoded, attn_mask in zip(tokenizer.batch_decode(batch_prefixes_input_ids), batch_prefixes_attention_mask):
    print(decoded)
    print(f"{attn_mask = }")
print()

# `max_new_tokens` bounds only the generated continuation, not the prompt.
batch_outputs = model.generate(
    inputs=batch_features,
    decoder_input_ids=batch_prefixes_input_ids,
    decoder_attention_mask=batch_prefixes_attention_mask,
    max_new_tokens=100,
    num_beams=5,
    do_sample=False,
)

batch_outputs_str = tokenizer.batch_decode(batch_outputs, skip_special_tokens=True)

# Loop variables are deliberately NOT named `label_str` / `forced_prefix_str`:
# those names hold the single-sample values from earlier cells and must not be
# overwritten here. The printed text is unchanged.
for expected_text, prefix_text, generated_text in zip(batch_labels_str, batch_prefixes_str, batch_outputs_str):
    print(f"label_str = {expected_text!r}")
    print(f"forced_prefix_str = {prefix_text!r}")
    print(f"output_str = {generated_text!r}")
    print()


FORCED PREFIXES
<|startoftranscript|><|en|><|transcribe|><|notimestamps|><|endoftext|><|endoftext|>This device
attn_mask = tensor([1, 1, 1, 1, 0, 0, 1, 1])
<|startoftranscript|><|en|><|transcribe|><|notimestamps|>Hello darkness my old
attn_mask = tensor([1, 1, 1, 1, 1, 1, 1, 1])

label_str = 'This device has a cathode inside an anode wire cage.'
forced_prefix_str = 'This device'
output_str = 'This device has a cathode inside an anode wire cage.'

label_str = 'This device has a cathode inside an anode wire cage.'
forced_prefix_str = 'Hello darkness my old'
output_str = 'Hello darkness my old friend. This device has a cathode inside an anode wire cage.'

