In [1]:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from datasets import load_dataset
import librosa

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# load model + processor
# model_id = "distil-whisper/distil-large-v2"
model_id = "distil-whisper/distil-small.en"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, use_safetensors=True
)
model.to(device)
encoder = model.get_encoder()

processor = AutoProcessor.from_pretrained(model_id)

# load dataset
# dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# sample = dataset[0]["audio"]["array"]
sample,sr = librosa.load("chunk_1.wav", sr=4000) 
print(sample)
print(sample.shape)

# preprocess inputs
input_features = processor(sample, return_tensors="pt").input_features
input_features = input_features.to(device, dtype=torch_dtype)

# forward pass to get encoder hidden states
with torch.no_grad():
    encoder_hidden_states = encoder(input_features).last_hidden_state

  from .autonotebook import tqdm as notebook_tqdm
It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


[0.00064078 0.00545278 0.02468038 ... 0.17281201 0.16281839 0.1476209 ]
(40000,)


In [2]:
print(encoder_hidden_states.shape)  

torch.Size([1, 1500, 768])


In [3]:
#print number of parameters in the model
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [4]:
from transformers import Wav2Vec2Model
base_model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
print("Number of parameters in the Whisper model: ", count_parameters(model))
print("Number of parameters in the Wav2Vec2 model: ", count_parameters(base_model))

Number of parameters in the Whisper model:  164980224
Number of parameters in the Wav2Vec2 model:  94371712
