Reference code:

- Hugging Face Transformers Whisper implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py

- Model checkpoint and config (whisper-tiny): https://huggingface.co/openai/whisper-tiny

In [None]:
# Fetch the whisper-tiny safetensors checkpoint; it is read later from /content/model.safetensors.
!wget https://huggingface.co/openai/whisper-tiny/resolve/main/model.safetensors

--2025-12-31 12:40:28--  https://huggingface.co/openai/whisper-tiny/resolve/main/model.safetensors
Resolving huggingface.co (huggingface.co)... 3.167.112.96, 3.167.112.45, 3.167.112.38, ...
Connecting to huggingface.co (huggingface.co)|3.167.112.96|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cas-bridge.xethub.hf.co/xet-bridge-us/63314bb6acb6472115aa55a9/de70b2cacb80b4c0e91f8e8c5c32e8004ed0b10a8cb13d9385b1afeb93ec14cc?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Content-Sha256=UNSIGNED-PAYLOAD&X-Amz-Credential=cas%2F20251231%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20251231T124028Z&X-Amz-Expires=3600&X-Amz-Signature=b54a0f071440e3c2085ed294a8a5cd3e013692971919690999637dc0e394e66b&X-Amz-SignedHeaders=host&X-Xet-Cas-Uid=public&response-content-disposition=inline%3B+filename*%3DUTF-8%27%27model.safetensors%3B+filename%3D%22model.safetensors%22%3B&x-id=GetObject&Expires=1767188428&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTO

In [None]:
# Extract harvard.wav from the archive downloaded from Kaggle:
# https://www.kaggle.com/datasets/pavanelisetty/sample-audio-files-for-speech-recognition
!unzip /content/harvard.wav.zip

Archive:  /content/harvard.wav.zip
  inflating: harvard.wav             


### Target: the Hugging Face transcription that my implementation should reproduce

In [None]:
import time
import torch
import librosa
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import warnings

warnings.filterwarnings("ignore")

# Reference pipeline: Hugging Face's whisper-tiny with greedy English transcription.
model_id = "openai/whisper-tiny"
processor = WhisperProcessor.from_pretrained(model_id)
model = WhisperForConditionalGeneration.from_pretrained(model_id)
audio, sr = librosa.load("/content/harvard.wav", sr=None, mono=True)
if sr != 16000:
    # Resample to the 16 kHz rate that is passed to the processor below.
    audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)
    sr = 16000
inputs = processor(audio, sampling_rate=sr, return_tensors="pt")

# perf_counter is monotonic and high-resolution, unlike the wall clock from time.time().
start_time = time.perf_counter()
with torch.no_grad():
    # `language`/`task` replace the deprecated `forced_decoder_ids` argument.
    predicted_ids = model.generate(inputs.input_features, language="en", task="transcribe")
text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
end_time = time.perf_counter()
print(f"Time taken: {end_time-start_time}")
print(text)  # timing above covers generation + decoding only (model loading excluded)



preprocessor_config.json: 0.00B [00:00, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

normalizer.json: 0.00B [00:00, ?B/s]

added_tokens.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/151M [00:00<?, ?B/s]

generation_config.json: 0.00B [00:00, ?B/s]

Using custom `forced_decoder_ids` from the (generation) config. This is deprecated in favor of the `task` and `language` flags/config options.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Time taken: 2.2056379318237305
 The stale smell of old beer lingers. It takes heat to bring out the odor. A cold dip restores health in zest. A salt pickle tastes fine with ham. Tacos all pastora are my favorite. A zestful food is the hot cross bun.


### My from-scratch implementation

In [None]:
import math
import time
import warnings
import numpy as np
import torch
import torch.nn as nn
import librosa
from transformers import WhisperProcessor, WhisperTokenizer
from safetensors.torch import load_file
from torch.profiler import profile, ProfilerActivity, record_function


warnings.filterwarnings("ignore")

In [None]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  With ``context=None`` this is self-attention; otherwise keys and values are
  projected from ``context`` (cross-attention). When ``is_decoder`` is True and
  no context is given, a causal mask keeps each position from attending to
  later positions. The key projection has no bias; q, v and out do.
  """

  def __init__(self, dim, num_heads, is_decoder=False):
    super().__init__()
    self.num_heads = num_heads
    self.head_dim = dim // num_heads
    self.scale = 1.0 / math.sqrt(self.head_dim)
    self.is_decoder = is_decoder
    # Creation order (k, q, v, out) is kept so parameter init and state_dict order are unchanged.
    self.k = nn.Linear(dim, dim, bias=False)
    self.q = nn.Linear(dim, dim, bias=True)
    self.v = nn.Linear(dim, dim, bias=True)
    self.out = nn.Linear(dim, dim, bias=True)

  def _to_heads(self, projected):
    """Split (batch, seq, dim) into (batch, heads, seq, head_dim)."""
    n_batch, n_tokens, _ = projected.size()
    return projected.reshape(n_batch, n_tokens, self.num_heads, self.head_dim).transpose(1, 2)

  def forward(self, x, context=None):
    n_batch, n_query, width = x.size()
    source = x if context is None else context

    queries = self._to_heads(self.q(x))
    keys = self._to_heads(self.k(source))
    values = self._to_heads(self.v(source))

    logits = (queries @ keys.transpose(-2, -1)) * self.scale
    if self.is_decoder and context is None:
      # Strictly-upper-triangular entries are "future" keys for each query row.
      upper = torch.ones(n_query, n_query, device=x.device)
      logits = logits.masked_fill(torch.triu(upper, diagonal=1).bool(), float("-inf"))

    probs = torch.softmax(logits, dim=-1)
    merged = (probs @ values).transpose(1, 2).reshape(n_batch, n_query, width)
    return self.out(merged)

In [None]:
class WhisperEncoderLayer(nn.Module):
  """Pre-norm Transformer encoder block.

  Computes ``h = x + SelfAttn(LN(x))`` and then ``h + MLP(LN(h))``, where the
  MLP is fc1 -> GELU -> fc2 with hidden width ``ff_dim``. The upstream
  implementation applies dropout after attention, but whisper-tiny's config
  sets it to 0, so it is omitted here.
  """

  def __init__(self, dim, num_heads, ff_dim):
    super().__init__()
    self.self_attn = MultiHeadAttention(dim, num_heads, is_decoder=False)
    self.self_attn_layer_norm = nn.LayerNorm(dim, elementwise_affine=True)
    self.activation_fn = nn.GELU()
    self.fc1 = nn.Linear(dim, ff_dim)
    self.fc2 = nn.Linear(ff_dim, dim)
    self.final_layer_norm = nn.LayerNorm(dim, elementwise_affine=True)

  def _mlp(self, hidden):
    """Position-wise feed-forward: fc1 -> GELU -> fc2."""
    return self.fc2(self.activation_fn(self.fc1(hidden)))

  def forward(self, x):
    attended = self.self_attn(self.self_attn_layer_norm(x))
    attended += x  # residual, added in place to the fresh attention output; x itself is untouched
    out = self._mlp(self.final_layer_norm(attended))
    out += attended
    return out

In [None]:
class WhisperEncoder(nn.Module):
  """Whisper audio encoder.

  Two Conv1d + GELU stages (the second uses twice conv1's stride, i.e. it
  downsamples in time), then learned position embeddings, a stack of
  ``WhisperEncoderLayer`` blocks and a final LayerNorm.

  Input is (batch, in_channels, frames) as required by Conv1d; output is
  (batch, frames_after_conv, embedding_dim).

  Raises:
    ValueError: if the convolved sequence is longer than ``num_embedding``
      position embeddings.
  """

  def __init__(self,
               in_channels,
               out_channels,
               kernel_size,
               stride,
               padding,
               num_embedding,
               embedding_dim,
               num_encoder_layers,
               num_heads,
               ff_dim):
    super().__init__()
    self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding)
    # GELU is stateless, so a single instance serves both conv stages.
    self.activation_fn = nn.GELU()
    self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, 2*stride, padding)
    self.embed_positions = nn.Embedding(num_embedding, embedding_dim)
    self.layers = nn.ModuleList([
        WhisperEncoderLayer(embedding_dim, num_heads, ff_dim)
        for _ in range(num_encoder_layers)
    ])
    self.layer_norm = nn.LayerNorm(embedding_dim, elementwise_affine=True)

  def forward(self, x):
    x = self.activation_fn(self.conv1(x))
    x = self.activation_fn(self.conv2(x))
    x = x.permute(0, 2, 1)  # (batch, channels, time) -> (batch, time, channels)

    num_frames = x.size(1)
    max_frames = self.embed_positions.num_embeddings
    if num_frames > max_frames:
      # Fail with a readable message instead of an opaque Embedding index error.
      raise ValueError(
          f"input yields {num_frames} frames after the convolutions, "
          f"but embed_positions only covers {max_frames}")

    positions = torch.arange(num_frames, device=x.device).unsqueeze(0)
    x = x + self.embed_positions(positions)
    for layer in self.layers:
      x = layer(x)
    return self.layer_norm(x)

In [None]:
class WhisperDecoderLayer(nn.Module):
  """Pre-norm Transformer decoder block.

  Three residual sub-layers, each preceded by its own LayerNorm:
  causal self-attention, cross-attention over ``encoder_output``, and an
  fc1 -> GELU -> fc2 MLP of hidden width ``ff_dim``.
  """

  def __init__(self, dim, num_heads, ff_dim):
    super().__init__()
    # is_decoder=True makes self-attention causal when no context is passed.
    self.self_attn = MultiHeadAttention(dim, num_heads, is_decoder=True)
    self.activation_fn = nn.GELU()
    self.self_attn_layer_norm = nn.LayerNorm(dim)
    self.encoder_attn = MultiHeadAttention(dim, num_heads, is_decoder=False)
    self.encoder_attn_layer_norm = nn.LayerNorm(dim)
    self.fc1 = nn.Linear(dim, ff_dim)
    self.fc2 = nn.Linear(ff_dim, dim)
    self.final_layer_norm = nn.LayerNorm(dim)

  def forward(self, x, encoder_output):
    residual = x
    x = self.self_attn_layer_norm(x)
    x = self.self_attn(x)
    x += residual
    residual = x
    x = self.encoder_attn_layer_norm(x)
    x = self.encoder_attn(x, context=encoder_output)
    x += residual
    residual = x
    x = self.final_layer_norm(x)
    x = self.fc1(x)
    x = self.activation_fn(x)
    x = self.fc2(x)
    x += residual
    return x

In [None]:
class WhisperDecoder(nn.Module):
  """Whisper text decoder.

  Sums token embeddings with learned position embeddings (positions start at
  0 on every call), runs the ``WhisperDecoderLayer`` stack against
  ``encoder_output``, and applies a final LayerNorm. Returns hidden states of
  shape (batch, seq, dim); the vocabulary projection lives in WhisperModel.

  Raises:
    ValueError: if ``input_ids`` is longer than ``max_positions``.
  """

  def __init__(self, vocab_size, dim, max_positions, num_decoder_layers, num_heads, ff_dim, padding_idx):
    super().__init__()
    self.embed_tokens = nn.Embedding(vocab_size, dim, padding_idx)
    self.embed_positions = nn.Embedding(max_positions, dim)
    self.layers = nn.ModuleList([
        WhisperDecoderLayer(dim, num_heads, ff_dim)
        for _ in range(num_decoder_layers)
    ])
    self.layer_norm = nn.LayerNorm(dim)

  def forward(self, input_ids, encoder_output):
    seq_len = input_ids.size(1)
    max_positions = self.embed_positions.num_embeddings
    if seq_len > max_positions:
      # Fail with a readable message instead of an opaque Embedding index error.
      raise ValueError(
          f"decoder input has {seq_len} tokens, "
          f"but embed_positions only covers {max_positions}")

    positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
    x = self.embed_tokens(input_ids) + self.embed_positions(positions)
    for layer in self.layers:
      x = layer(x, encoder_output)
    return self.layer_norm(x)

In [None]:
class WhisperModel(nn.Module):
  """Whisper encoder-decoder with an output projection tied to the token embedding.

  Every constructor argument defaults to the whisper-tiny value, so
  ``WhisperModel()`` builds whisper-tiny. The defaults are 80 mel bins,
  d_model 384, 1500 encoder / 448 decoder positions, 4 + 4 layers, 6 heads,
  FFN width 1536, a 51865-token vocabulary and pad id 50257. Pass other values
  to build other sizes with the same architecture.
  """

  def __init__(self,
               num_mel_bins=80,
               d_model=384,
               max_source_positions=1500,
               max_target_positions=448,
               encoder_layers=4,
               decoder_layers=4,
               attention_heads=6,
               ffn_dim=1536,
               vocab_size=51865,
               pad_token_id=50257):
    super().__init__()
    self.encoder = WhisperEncoder(
            in_channels=num_mel_bins,            # mel spectrogram features
            out_channels=d_model,                # model dimension
            kernel_size=3,                       # conv kernel size
            stride=1,                            # conv1 stride (conv2 uses 2x)
            padding=1,                           # conv padding
            num_embedding=max_source_positions,  # max encoder positions
            embedding_dim=d_model,
            num_encoder_layers=encoder_layers,
            num_heads=attention_heads,
            ff_dim=ffn_dim)
    self.decoder = WhisperDecoder(
            vocab_size=vocab_size,
            dim=d_model,
            max_positions=max_target_positions,
            num_decoder_layers=decoder_layers,
            num_heads=attention_heads,
            ff_dim=ffn_dim,
            padding_idx=pad_token_id)
    self.proj_out = nn.Linear(d_model, vocab_size, bias=False)
    # Weight tying: the vocabulary projection shares its matrix with the token embedding.
    self.proj_out.weight = self.decoder.embed_tokens.weight

  def forward(self, mel_spectrogram, decoder_input_ids, encoder_output=None):
    """Return ``(logits, encoder_output)``.

    When ``encoder_output`` is given, the encoder is skipped and
    ``mel_spectrogram`` is not used, so a caller can encode once and reuse the
    result. ``logits`` covers every decoder position (last dim = vocab size).
    """
    if encoder_output is None:
      encoder_output = self.encoder(mel_spectrogram)

    decoder_output = self.decoder(decoder_input_ids, encoder_output)
    logits = self.proj_out(decoder_output)
    return logits, encoder_output

In [None]:
# List every tensor name and shape stored in the HF checkpoint (used to build the key mapping).
checkpoint = load_file("/content/model.safetensors")
for name, tensor in checkpoint.items():
    print(f"{name}: {tensor.shape}")

model.decoder.embed_positions.weight: torch.Size([448, 384])
model.decoder.embed_tokens.weight: torch.Size([51865, 384])
model.decoder.layer_norm.bias: torch.Size([384])
model.decoder.layer_norm.weight: torch.Size([384])
model.decoder.layers.0.encoder_attn.k_proj.weight: torch.Size([384, 384])
model.decoder.layers.0.encoder_attn.out_proj.bias: torch.Size([384])
model.decoder.layers.0.encoder_attn.out_proj.weight: torch.Size([384, 384])
model.decoder.layers.0.encoder_attn.q_proj.bias: torch.Size([384])
model.decoder.layers.0.encoder_attn.q_proj.weight: torch.Size([384, 384])
model.decoder.layers.0.encoder_attn.v_proj.bias: torch.Size([384])
model.decoder.layers.0.encoder_attn.v_proj.weight: torch.Size([384, 384])
model.decoder.layers.0.encoder_attn_layer_norm.bias: torch.Size([384])
model.decoder.layers.0.encoder_attn_layer_norm.weight: torch.Size([384])
model.decoder.layers.0.fc1.bias: torch.Size([1536])
model.decoder.layers.0.fc1.weight: torch.Size([1536, 384])
model.decoder.layers.0.

In [None]:
def load_whisper_weights(my_model, hf_checkpoint_path):
    """Copy a Hugging Face Whisper safetensors checkpoint into ``my_model`` in place.

    HF parameter names such as ``model.encoder.layers.0.self_attn.q_proj.weight``
    are mapped onto this notebook's names (``encoder.layers.0.self_attn.q.weight``).
    Layer counts are read from ``my_model``, so models with any depth are supported.

    Args:
        my_model: a WhisperModel instance; its parameters are overwritten.
        hf_checkpoint_path: path to the HF ``model.safetensors`` file.

    Returns:
        ``my_model`` (the same object), for convenient chaining.

    Raises:
        RuntimeError: from ``load_state_dict`` if a tensor's shape does not match.
    """
    hf_weights = load_file(hf_checkpoint_path)
    my_state_dict = my_model.state_dict()

    mapping = {
        'model.encoder.conv1.weight': 'encoder.conv1.weight',
        'model.encoder.conv1.bias': 'encoder.conv1.bias',
        'model.encoder.conv2.weight': 'encoder.conv2.weight',
        'model.encoder.conv2.bias': 'encoder.conv2.bias',
        'model.encoder.embed_positions.weight': 'encoder.embed_positions.weight',
        'model.encoder.layer_norm.weight': 'encoder.layer_norm.weight',
        'model.encoder.layer_norm.bias': 'encoder.layer_norm.bias',

        'model.decoder.embed_tokens.weight': 'decoder.embed_tokens.weight',
        'model.decoder.embed_positions.weight': 'decoder.embed_positions.weight',
        'model.decoder.layer_norm.weight': 'decoder.layer_norm.weight',
        'model.decoder.layer_norm.bias': 'decoder.layer_norm.bias',
    }

    # (HF suffix, local suffix) for one attention block; the key projection has no bias.
    attn_params = [
        ('k_proj.weight', 'k.weight'),
        ('q_proj.weight', 'q.weight'), ('q_proj.bias', 'q.bias'),
        ('v_proj.weight', 'v.weight'), ('v_proj.bias', 'v.bias'),
        ('out_proj.weight', 'out.weight'), ('out_proj.bias', 'out.bias'),
    ]
    # Parameters whose names are identical in both models.
    mlp_params = [
        'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias',
        'final_layer_norm.weight', 'final_layer_norm.bias',
    ]

    def add_layer(stack, index, attn_names):
        """Add mapping entries for layer ``index`` of ``stack`` ('encoder' or 'decoder')."""
        hf_prefix = f'model.{stack}.layers.{index}.'
        my_prefix = f'{stack}.layers.{index}.'
        for attn in attn_names:
            for hf_suffix, my_suffix in attn_params:
                mapping[f'{hf_prefix}{attn}.{hf_suffix}'] = f'{my_prefix}{attn}.{my_suffix}'
            for param in ('weight', 'bias'):
                mapping[f'{hf_prefix}{attn}_layer_norm.{param}'] = f'{my_prefix}{attn}_layer_norm.{param}'
        for suffix in mlp_params:
            mapping[hf_prefix + suffix] = my_prefix + suffix

    for i in range(len(my_model.encoder.layers)):
        add_layer('encoder', i, ('self_attn',))
    for i in range(len(my_model.decoder.layers)):
        add_layer('decoder', i, ('self_attn', 'encoder_attn'))

    new_state_dict = {}
    for hf_key, my_key in mapping.items():
        if hf_key in hf_weights:
            new_state_dict[my_key] = hf_weights[hf_key]
        else:
            print(f"Missing in HF checkpoint: {hf_key}")

    # proj_out.weight needs no copy only when it really is the embedding tensor.
    tied = my_model.proj_out.weight is my_model.decoder.embed_tokens.weight
    tied_weights = {'proj_out.weight'} if tied else set()

    unused = set(hf_weights) - set(mapping) - tied_weights
    if unused:
        print(f"Unused HF checkpoint keys: {sorted(unused)}")

    missing_keys = set(my_state_dict.keys()) - set(new_state_dict.keys()) - tied_weights
    if missing_keys:
        print(f"Missing keys in your model: {missing_keys}")

    # strict=False because the tied proj_out.weight is intentionally absent;
    # shape mismatches still raise RuntimeError regardless of strict.
    result = my_model.load_state_dict(new_state_dict, strict=False)
    if result.unexpected_keys:
        print(f"Unexpected keys (not in your model): {result.unexpected_keys}")

    if missing_keys or result.unexpected_keys:
        print("Weights loaded with problems (see messages above).")
    else:
        print("Weights loaded successfully!")
    if tied:
        print("Note: proj_out.weight is tied to decoder.embed_tokens.weight (weight sharing)")

    return my_model

In [None]:
# Build the from-scratch model and copy the HF whisper-tiny weights into it.
my_model = load_whisper_weights(WhisperModel(), "/content/model.safetensors")

Weights loaded successfully!
Note: proj_out.weight is tied to decoder.embed_tokens.weight (weight sharing)


In [None]:
def load_and_preprocess_audio(audio_path, processor=None, verbose=False):
    """Load an audio file as 16 kHz mono and return the processor's input features.

    Args:
        audio_path: path to an audio file readable by ``librosa.load``.
        processor: optional, already-loaded WhisperProcessor. When None, it is
            loaded from "openai/whisper-tiny". Pass one in to avoid reloading it
            on every call.
        verbose: when True, print the processor output type and feature shape.

    Returns:
        ``input_features`` tensor from the processor. It is presumably
        (batch, 80 mel bins, frames) for whisper-tiny; confirm against the processor.
    """
    if processor is None:
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
    # Passing sr=16000 makes librosa resample on load; the native rate is not needed.
    audio, _ = librosa.load(audio_path, sr=16000, mono=True)
    inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
    features = inputs.input_features
    if verbose:
        print("type inputs:", type(inputs))
        print("inputs shape:", features.shape)
    return features

In [None]:
def _split_heads(attn, t):
    """Reshape (batch, seq, dim) to (batch, heads, seq, head_dim) using ``attn``'s head layout."""
    batch, length, _ = t.shape
    return t.reshape(batch, length, attn.num_heads, attn.head_dim).transpose(1, 2)


def _attend(attn, q, k, v, past_len=None):
    """Scaled dot-product attention over pre-split heads, followed by ``attn.out``.

    ``past_len`` turns on causal masking for a chunk of queries whose first
    query sits at absolute position ``past_len``. Pass None for unmasked
    (cross) attention.
    """
    scores = (q @ k.transpose(-2, -1)) * attn.scale
    if past_len is not None:
        # Query i is at absolute position past_len + i and may only see keys at positions <= that.
        future = torch.ones(q.size(2), k.size(2), dtype=torch.bool, device=q.device).triu(diagonal=past_len + 1)
        scores = scores.masked_fill(future, float("-inf"))
    mixed = torch.softmax(scores, dim=-1) @ v
    batch, heads, length, head_dim = mixed.shape
    return attn.out(mixed.transpose(1, 2).reshape(batch, length, heads * head_dim))


def _decoder_step(model, token_ids, past_len, self_kv, cross_kv):
    """Run the decoder on the new tokens only, reusing cached keys/values.

    This mirrors WhisperDecoderLayer.forward (pre-norm causal self-attention,
    then cross-attention, then the GELU MLP). Keep the two in sync if either
    changes.

    Returns the vocabulary logits of the last position and the updated
    per-layer self-attention (k, v) cache.
    """
    decoder = model.decoder
    positions = torch.arange(past_len, past_len + token_ids.size(1), device=token_ids.device).unsqueeze(0)
    x = decoder.embed_tokens(token_ids) + decoder.embed_positions(positions)
    new_self_kv = []
    for layer, cache, (cross_k, cross_v) in zip(decoder.layers, self_kv, cross_kv):
        attn = layer.self_attn
        h = layer.self_attn_layer_norm(x)
        q = _split_heads(attn, attn.q(h))
        k = _split_heads(attn, attn.k(h))
        v = _split_heads(attn, attn.v(h))
        if cache is not None:
            k = torch.cat([cache[0], k], dim=2)
            v = torch.cat([cache[1], v], dim=2)
        new_self_kv.append((k, v))
        x = x + _attend(attn, q, k, v, past_len)

        cross = layer.encoder_attn
        h = layer.encoder_attn_layer_norm(x)
        x = x + _attend(cross, _split_heads(cross, cross.q(h)), cross_k, cross_v)

        h = layer.final_layer_norm(x)
        x = x + layer.fc2(layer.activation_fn(layer.fc1(h)))
    # LayerNorm acts per position, so normalizing and projecting only the last one is equivalent
    # and avoids the (seq x vocab) projection over the whole prefix.
    last = decoder.layer_norm(x[:, -1:, :])
    return model.proj_out(last)[0, -1], new_self_kv


def generate_transcription(model, mel_spectrogram, tokenizer, max_length=448, device="cpu"):
    """Greedy-decode an English transcription with key/value caching.

    The encoder runs once, and each decoder layer's cross-attention keys/values
    are projected once from its output. Self-attention keys/values are cached,
    so every step after the prompt feeds a single new token. The original
    implementation recomputed all of these at every step.

    Args:
        model: a WhisperModel (uses its ``encoder``, ``decoder`` and ``proj_out``).
        mel_spectrogram: encoder input features; moved to ``device``.
        tokenizer: Whisper tokenizer used to encode the prompt/EOT tokens and decode output.
        max_length: total decoder length budget (prompt + generated tokens).
        device: device the model and features are moved to.

    Returns:
        The decoded text, with special tokens skipped.
    """
    model.eval()
    model.to(device)
    mel_spectrogram = mel_spectrogram.to(device)

    start_tokens = tokenizer.encode(
        "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
        add_special_tokens=False
    )
    eot_token_id = tokenizer.encode("<|endoftext|>", add_special_tokens=False)[0]
    generated_tokens = []

    with torch.no_grad():
        encoder_output = model.encoder(mel_spectrogram)
        # Cross-attention K/V depend only on the encoder output: project them once per layer.
        cross_kv = []
        for layer in model.decoder.layers:
            cross = layer.encoder_attn
            cross_kv.append((_split_heads(cross, cross.k(encoder_output)),
                             _split_heads(cross, cross.v(encoder_output))))
        self_kv = [None] * len(model.decoder.layers)

        next_input = torch.tensor([start_tokens], device=device)
        past_len = 0
        for step in range(max_length - len(start_tokens)):
            logits, self_kv = _decoder_step(model, next_input, past_len, self_kv, cross_kv)
            past_len += next_input.size(1)
            next_token_id = torch.argmax(logits, dim=-1).item()
            generated_tokens.append(next_token_id)

            if next_token_id == eot_token_id:
                break

            if step >= 200:  # safety cap on generated tokens
                break

            next_input = torch.tensor([[next_token_id]], device=device)

    transcription = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    return transcription

In [None]:
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
audio_path = "/content/harvard.wav"
mel_spec = load_and_preprocess_audio(audio_path)

# perf_counter is monotonic and high-resolution; time.time() is wall-clock and can jump.
start = time.perf_counter()
transcription = generate_transcription(my_model, mel_spec, tokenizer, device='cpu')
end = time.perf_counter()

print(f"\nTranscription: {transcription}")
print(f"Time taken: {end - start:.2f} seconds")  # compare with the HF reference timing above

type inputs: <class 'transformers.feature_extraction_utils.BatchFeature'>
inputs: {'input_features': tensor([[[-0.4328, -0.4328, -0.4328,  ..., -0.4328, -0.4328, -0.4328],
         [-0.4328, -0.4328, -0.4328,  ..., -0.4328, -0.4328, -0.4328],
         [-0.4328, -0.4328, -0.4328,  ..., -0.4328, -0.4328, -0.4328],
         ...,
         [-0.4328, -0.4328, -0.4328,  ..., -0.4328, -0.4328, -0.4328],
         [-0.4328, -0.4328, -0.4328,  ..., -0.4328, -0.4328, -0.4328],
         [-0.4328, -0.4328, -0.4328,  ..., -0.4328, -0.4328, -0.4328]]])}
inputs.input_features: tensor([[[-0.4328, -0.4328, -0.4328,  ..., -0.4328, -0.4328, -0.4328],
         [-0.4328, -0.4328, -0.4328,  ..., -0.4328, -0.4328, -0.4328],
         [-0.4328, -0.4328, -0.4328,  ..., -0.4328, -0.4328, -0.4328],
         ...,
         [-0.4328, -0.4328, -0.4328,  ..., -0.4328, -0.4328, -0.4328],
         [-0.4328, -0.4328, -0.4328,  ..., -0.4328, -0.4328, -0.4328],
         [-0.4328, -0.4328, -0.4328,  ..., -0.4328, -0.4328, -0.

In [None]:
# Profile one full CPU transcription and show the 20 ops with the highest self CPU time.
with profile(activities=[ProfilerActivity.CPU], record_shapes=True, profile_memory=True) as prof, \
        record_function("model_inference"):
    transcription = generate_transcription(my_model, mel_spec, tokenizer, device="cpu")
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=20))

-------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
-------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                               aten::linear         0.40%      31.858ms        78.50%        6.236s       2.688ms       1.49 GB           0 B          2320  
                               aten::matmul         0.39%      31.123ms        53.82%        4.275s       3.028ms       1.31 GB           0 B          1412  
                                   aten::mm        43.00%        3.416s        43.01%        3.417s       6.726ms     860.32 MB     860.32 MB           508  
                                aten::addmm        3