In [1]:
from transformers import WhisperModel, WhisperConfig, WhisperFeatureExtractor, WhisperForConditionalGeneration
from transformers import WhisperProcessor, WhisperTokenizer
from datasets import load_dataset
import torch

In [2]:
# Checkpoint identifier handed to every `from_pretrained` call in the
# model-loading cell (model, tokenizer and processor all use this one name).
MODEL_NAME = "openai/whisper-medium"

In [3]:
# Text/audio preprocessing objects for the MODEL_NAME checkpoint.
whisper_processor = WhisperProcessor.from_pretrained(MODEL_NAME)
pretrained_whisper_tokenizer = WhisperTokenizer.from_pretrained(MODEL_NAME)

# Two model objects built from the same checkpoint name:
#   - `pretrained_whisper`: the WhisperModel class
#   - `whisper_for_gen`:    the WhisperForConditionalGeneration class, used for `.generate()`
pretrained_whisper = WhisperModel.from_pretrained(MODEL_NAME)
whisper_for_gen = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME)

In [4]:
# A look at a Whisper config.
# NOTE(review): `WhisperConfig()` is built with no arguments here, so this
# prints the config class's default values rather than the configuration of
# the MODEL_NAME checkpoint loaded above. The checkpoint's own config is
# reachable as `pretrained_whisper.config` / `whisper_for_gen.config`.
config = WhisperConfig()
print(config)


WhisperConfig {
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "attention_dropout": 0.0,
  "begin_suppress_tokens": [
    220,
    50256
  ],
  "bos_token_id": 50257,
  "d_model": 256,
  "decoder_attention_heads": 4,
  "decoder_ffn_dim": 1536,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 6,
  "decoder_start_token_id": 50257,
  "dropout": 0.0,
  "encoder_attention_heads": 4,
  "encoder_ffn_dim": 1536,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 6,
  "eos_token_id": 50256,
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "max_source_positions": 1500,
  "max_target_positions": 448,
  "model_type": "whisper",
  "num_hidden_layers": 6,
  "num_mel_bins": 80,
  "pad_token_id": 50256,
  "scale_embedding": false,
  "transformers_version": "4.23.1",
  "use_cache": true,
  "vocab_size": 51865
}



In [5]:
# Inspect the tokenizer's BOS / EOS special tokens together with their ids.
# (Uncomment the next line to list every special token instead.)
# print(pretrained_whisper_tokenizer.all_special_tokens)
for label, special_token, special_token_id in (
    ("BOS token", pretrained_whisper_tokenizer.bos_token, pretrained_whisper_tokenizer.bos_token_id),
    ("EOS token", pretrained_whisper_tokenizer.eos_token, pretrained_whisper_tokenizer.eos_token_id),
):
    print(label, special_token, "===>", special_token_id)

BOS token <|endoftext|> ===> 50257
EOS token <|endoftext|> ===> 50257


In [6]:
# Show which token ids the generation config forces at fixed decoder
# positions, then which token ids it suppresses during generation.
forced_ids = whisper_for_gen.config.forced_decoder_ids

# Print the header once, not once per forced id.
print("IDs the model is forced to predict at each timestep")
# `forced_decoder_ids` can be None (a later cell in this notebook sets it to
# None on this very config), so fall back to an empty sequence to keep this
# cell re-runnable instead of raising TypeError.
for idx, token_id in forced_ids or []:
    token = pretrained_whisper_tokenizer.decode(token_id)
    print(f"timestep {idx}: {token} => {token_id}")

# Same None-guard for the suppressed ids; each is shown as "(token, id)".
ids_to_suppress = whisper_for_gen.config.suppress_tokens
for sup_id in ids_to_suppress or []:
    sup_token = pretrained_whisper_tokenizer.decode(sup_id)
    print(f"({sup_token}, {sup_id})", end=" ")

IDs the model is forced to predict at each timestep
timestep 1: <|en|> => 50259
IDs the model is forced to predict at each timestep
timestep 2: <|transcribe|> => 50359
IDs the model is forced to predict at each timestep
timestep 3: <|notimestamps|> => 50363
(", 1) (#, 2) (', 6) ((, 7) (), 8) (*, 9) (+, 10) (-, 12) (/, 14) (:, 25) (;, 26) (<, 27) (=, 28) (>, 29) (@, 31) ([, 58) (\, 59) (], 60) (^, 61) (_, 62) (`, 63) ({, 90) (|, 91) (}, 92) (~, 93) ( -, 359) ( ", 503) ( (, 522) ( [, 542) ( �, 873) (>>, 893) ( >>, 902) (--, 918) ( ', 922) ( ♪, 931) ( --, 1350) ( *, 1853) ( :, 1982) ( /, 2460) ( <, 2627) (「, 3246) (」, 3253) (�, 3268) ( #, 3536) ( ♫, 3846) (♪, 3961) ( ], 4183) ( +, 4667) ( =, 6585) ( -(, 6647) ( ), 7273) ( ♪♪, 9061) ()), 9383) ( @, 10428) ( {, 10929) ( ~, 11938) ( \, 12033) ( >, 12331) ( ;, 12562) ( >>>, 13793) (♫, 14157) ( -[, 14635) ( ((, 15265) ( (", 15618) (『, 16553) (』, 16604) ( |, 18362) ( ^, 18956) (---, 20075) ( 「, 21675) ( ♬, 22520) (♪♪, 26130) ( _, 26161) ( ))), 

In [7]:
# Display the BOS token id stored on the loaded WhisperModel's config
# (bare last expression, so the notebook renders its value).
pretrained_whisper.config.bos_token_id

50257

In [9]:
# Lift every decoding constraint on the generation model's config:
#   - do not suppress any token, so every token can be predicted
#   - do not force the language / task tokens at the start of decoding
# NOTE: this mutates `whisper_for_gen.config` in place, so the change
# persists for every cell executed after this one.
whisper_for_gen.config.forced_decoder_ids = None
whisper_for_gen.config.suppress_tokens = []

# 16000 zero-valued samples, passed to the processor as 16 kHz audio.
silences = torch.zeros(16000)

input_features = whisper_processor(
    silences,
    sampling_rate=16000,
    return_tensors="pt",
).input_features

# Despite the variable name, `generate` output is handed straight to
# `batch_decode`, i.e. it is treated as token ids rather than raw scores.
logits = whisper_for_gen.generate(input_features, max_length=448)
decoded = whisper_processor.batch_decode(logits)

print("silence decoded : {}".format(decoded))

silence decoded : ['<|startoftranscript|><|nocaptions|><|endoftext|>']


In [11]:
# Load the "validation" split of the "clean" configuration of the
# `librispeech_asr` dataset, then print its first example.
libri_dataset = load_dataset(
    path="librispeech_asr",
    name="clean",
    split="validation",
)
print(libri_dataset[0])

Downloading and preparing dataset librispeech_asr/clean to C:/Users/shast/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/cff5df6e7955c80a67f80e27e7e655de71c689e2d2364bece785b972acb37fe7...


Downloading data files:   0%|          | 0/4 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/338M [00:00<?, ?B/s]

In [None]:
# Transcribe the first LibriSpeech validation example with all decoding
# constraints removed, and compare against its reference text.

# Predict from the full vocabulary: do not suppress any token.
whisper_for_gen.config.suppress_tokens = []
# Do not force the decoder to emit the language / task tokens initially.
whisper_for_gen.config.forced_decoder_ids = None

# Fetch the example once and reuse it: indexing the dataset again for the
# reference text would repeat the row access (and its audio decoding).
sample = libri_dataset[0]
audio = sample["audio"]

# Use the sampling rate recorded with the audio instead of a hard-coded
# 16000, so a mismatch with the feature extractor is reported rather than
# silently producing wrong features.
input_features = whisper_processor(
    audio["array"],
    return_tensors="pt",
    sampling_rate=audio["sampling_rate"],
).input_features

# `generate` output is passed directly to `batch_decode`, i.e. it holds
# token ids (the name `logits` is kept for continuity with earlier cells).
logits = whisper_for_gen.generate(input_features, max_length=448)
decoded = whisper_processor.batch_decode(logits)

print("a datapoint from [librispeech-val-clean] decoded : {}".format(decoded))
print("a datapoint from [librispeech-val-clean] original: {}".format(sample["text"]))