In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plotter   # type: ignore[attr-defined]
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from transformers import AutoFeatureExtractor, Wav2Vec2BertModel, AutoProcessor


In [None]:
# from transformers import AutoFeatureExtractor, Wav2Vec2BertModel
# import torch
# from datasets import load_dataset

# dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
# dataset = dataset.sort("id")
# sampling_rate = dataset.features["audio"].sampling_rate

# processor = AutoProcessor.from_pretrained("facebook/w2v-bert-2.0")
# model = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0")

# # audio file is decoded on the fly
# inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
# with torch.no_grad():
#     outputs = model(**inputs)


In [None]:
model_id = "openai/whisper-base"    
model = WhisperForConditionalGeneration.from_pretrained(model_id)
processor = WhisperProcessor.from_pretrained(model_id)

In [None]:
import inspect
print(inspect.getsource(model.model.encoder.layers[0].self_attn))

In [None]:
for layer in model.model.encoder.layers:        
    layer.self_attn.is_causal = True            # type: ignore[attr-defined]

In [None]:
audio_5s = torch.randn(16000 * 10).numpy()
audio_7s = np.concatenate([audio_5s, torch.randn(16000 * 6).numpy()])

inputs_5s = processor(audio_5s, sampling_rate=16000, return_tensors="pt") # type: ignore[attr-defined]
inputs_7s = processor(audio_7s, sampling_rate=16000, return_tensors="pt") # type: ignore[attr-defined]

print(f"Mel spectrogram 5s: {inputs_5s.input_features.shape}")
print(f"Mel spectrogram 7s: {inputs_7s.input_features.shape}")

In [None]:
F, axes = plotter.subplots(5, 1, figsize=(15, 12))

for i in range(5):
    feature_5s = inputs_5s.input_features[0, i, :].numpy()
    feature_7s = inputs_7s.input_features[0, i, :].numpy()
    
    axes[i].plot(feature_5s, label='5s audio', alpha=0.7, linewidth=1)
    axes[i].plot(feature_7s, label='7s audio', alpha=0.7, linewidth=1)
    axes[i].set_title(f'Mel Frequency Bin {i}')
    axes[i].set_xlabel('Time Frame')
    axes[i].set_ylabel('Mel Value')
    axes[i].legend()
    axes[i].grid(True, alpha=0.3)

plotter.tight_layout()
plotter.show()

In [None]:
feature_5s = inputs_5s.input_features
feature_7s = inputs_7s.input_features

tgt_len   = feature_5s.shape[-1] // 2                 # 3000 → 1500 after conv

# start with lower-triangular 0/-inf
causal    = torch.tril(torch.ones(tgt_len, tgt_len))

attn_mask = torch.where(causal == 1,
                        torch.tensor(0.0),
                        torch.tensor(float("-inf"))
                       )[None, None, :, :]          # (1,1,tgt,src) additive mask

attn_mask = attn_mask.to(dtype=model.model.encoder.embed_positions.weight.dtype)


model.eval()
# Generate transcription
with torch.no_grad():
    latents_5s = model.model.encoder(feature_5s, attention_mask=attn_mask)
    latents_7s = model.model.encoder(feature_7s, attention_mask=attn_mask)
    

# # Process both through the encoder
# with torch.no_grad():
#     latents_5s = model.model.encoder(feature_5s)
#     latents_7s = model.model.encoder(feature_7s)
    

latents_5s_tensor = latents_5s.last_hidden_state
latents_7s_tensor = latents_7s.last_hidden_state

print(f"5s latents shape: {latents_5s_tensor.shape}")
print(f"7s latents shape: {latents_7s_tensor.shape}")

In [None]:
for i in range(1500):
    l1 = latents_5s_tensor[0][i]
    l2 = latents_7s_tensor[0][i]
    if not np.allclose(l1, l2):
        print(i)
        break

In [None]:
print(latents_5s_tensor[0][150])
print("hello")
print(latents_7s_tensor[0][450])

In [None]:
import inspect
print(inspect.getsource(model.model.encoder.forward))

In [20]:
from transformers import Wav2Vec2BertModel

In [None]:
# Load model
model_id = "facebook/w2v-bert-2.0"
model = Wav2Vec2BertModel.from_pretrained(model_id)

In [24]:
print(model.forward)

<bound method Wav2Vec2BertModel.forward of Wav2Vec2BertModel(
  (feature_projection): Wav2Vec2BertFeatureProjection(
    (layer_norm): LayerNorm((160,), eps=1e-05, elementwise_affine=True)
    (projection): Linear(in_features=160, out_features=1024, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): Wav2Vec2BertEncoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-23): 24 x Wav2Vec2BertEncoderLayer(
        (ffn1_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (ffn1): Wav2Vec2BertFeedForward(
          (intermediate_dropout): Dropout(p=0.0, inplace=False)
          (intermediate_dense): Linear(in_features=1024, out_features=4096, bias=True)
          (intermediate_act_fn): SiLU()
          (output_dense): Linear(in_features=4096, out_features=1024, bias=True)
          (output_dropout): Dropout(p=0.0, inplace=False)
        )
        (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_aff

In [44]:
print(f"Number of encoder layers: {len(model.encoder.layers)}")

lennn = 500

audio_5s = torch.randn((1, 4, 160))
audio_7s = np.concatenate((audio_5s.numpy(), torch.randn((1, 4, 160)).numpy()), axis=1)
audio_7s = torch.from_numpy(audio_7s)

print(f"A: {audio_5s.shape}")
print(f"Long tensor shape: {audio_7s.shape}")

Number of encoder layers: 24
A: torch.Size([1, 4, 160])
Long tensor shape: torch.Size([1, 8, 160])


In [45]:
# Run inference
with torch.no_grad():
    o1 = model(audio_5s)
    o2 = model(audio_7s)

 

In [62]:
print(o1.last_hidden_state.shape)
print(o2.last_hidden_state.shape)

o1h = o1.last_hidden_state
o2h = o2.last_hidden_state

torch.Size([1, 4, 1024])
torch.Size([1, 8, 1024])


In [64]:
for i in range(4):
    print(o1h[0][i])
    print(o2h[0][i])
    print("--------------------------------")

tensor([ 0.0662,  0.0496, -0.0166,  ...,  0.0097,  0.0438,  0.0263])
tensor([ 0.0570,  0.0727,  0.0057,  ...,  0.0039,  0.0041, -0.0060])
--------------------------------
tensor([ 0.0227, -0.0179,  0.0021,  ..., -0.0771,  0.0905,  0.1242])
tensor([-0.0341, -0.0060,  0.0105,  ..., -0.0727,  0.0736,  0.1535])
--------------------------------
tensor([ 0.0565, -0.0304, -0.0194,  ..., -0.0580,  0.0580,  0.0551])
tensor([ 0.0845, -0.0100, -0.0025,  ..., -0.0313,  0.0206,  0.0607])
--------------------------------
tensor([ 0.0040,  0.0202, -0.0066,  ...,  0.0107, -0.0203,  0.0550])
tensor([ 0.0687, -0.0048,  0.0013,  ..., -0.0200,  0.0064,  0.0125])
--------------------------------
