In [87]:
import torch
import numpy as np
from torch import nn, Tensor
from transformers import WhisperForConditionalGeneration
from collections import defaultdict
from dataclasses import dataclass

In [88]:
# Path to your checkpoint
checkpoint_path = "output/custom_librispeech_test/checkpoint-490"

# Load the model
model = WhisperForConditionalGeneration.from_pretrained(checkpoint_path)

  return torch.load(checkpoint_file, map_location="cpu")


In [89]:
# Print model structure
print("Model structure:\n")
print(model)

Model structure:

WhisperForConditionalGeneration(
  (model): WhisperModel(
    (encoder): WhisperEncoder(
      (conv1): Conv1d(80, 384, kernel_size=(3,), stride=(1,), padding=(1,))
      (conv2): Conv1d(384, 384, kernel_size=(3,), stride=(2,), padding=(1,))
      (embed_positions): Embedding(1500, 384)
      (layers): ModuleList(
        (0-3): 4 x WhisperEncoderLayer(
          (self_attn): WhisperAttention(
            (k_proj): Linear(in_features=384, out_features=384, bias=False)
            (v_proj): Linear(in_features=384, out_features=384, bias=True)
            (q_proj): Linear(in_features=384, out_features=384, bias=True)
            (out_proj): Linear(in_features=384, out_features=384, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=384, out_features=1536, bias=True)
          (fc2): Linear(in_features=1536, out_features=384, bias=Tr

In [90]:
# Print module types
module_names = defaultdict(int)

# Iterating through all of the modules
for name, module in model.named_modules():
    module_names[type(module).__name__] += 1

# Print cleanly
print("\nModule type".ljust(40), "Count")
print("-" * 55)
for module_type, count in sorted(module_names.items(), key=lambda x: x[0]):
    print(f"{module_type.ljust(40)} {count}")


Module type                             Count
-------------------------------------------------------
Conv1d                                   2
Embedding                                2
GELUActivation                           8
LayerNorm                                22
Linear                                   65
ModuleList                               2
WhisperAttention                         12
WhisperDecoder                           1
WhisperDecoderLayer                      4
WhisperEncoder                           1
WhisperEncoderLayer                      4
WhisperForConditionalGeneration          1
WhisperModel                             1
WhisperPositionalEmbedding               1


In [91]:
# Define a dataclass to store the cache
@dataclass
class OBSLinearCache:
    # Initilize empty cache
    name: str = None
    weight: Tensor = None
    input: Tensor = None
    output: Tensor = None
    module: nn.Linear = None

# Define a function to populate the cache
def get_layer_hook(name: str):
    # Create instance of cache
    cache = OBSLinearCache()

    def hook_fn(module, args, outputs):
        # Update cache
        cache.module = module
        cache.name = name
        cache.input = args
        cache.output = outputs
        if hasattr(module, "weight"):
            cache.weight = module.weight

    return hook_fn, cache        

In [99]:
# Create caches and hooks throughout the model
caches = {}
hooks = {}

for name, module in model.named_modules():
    hook_fn, cache = get_layer_hook(name)
    caches[name] = cache
    hooks[name] = module.register_forward_hook(hook_fn)

print(caches)

{'': OBSLinearCache(name=None, weight=None, input=None, output=None, module=None), 'model': OBSLinearCache(name=None, weight=None, input=None, output=None, module=None), 'model.encoder': OBSLinearCache(name=None, weight=None, input=None, output=None, module=None), 'model.encoder.conv1': OBSLinearCache(name=None, weight=None, input=None, output=None, module=None), 'model.encoder.conv2': OBSLinearCache(name=None, weight=None, input=None, output=None, module=None), 'model.encoder.embed_positions': OBSLinearCache(name=None, weight=None, input=None, output=None, module=None), 'model.encoder.layers': OBSLinearCache(name=None, weight=None, input=None, output=None, module=None), 'model.encoder.layers.0': OBSLinearCache(name=None, weight=None, input=None, output=None, module=None), 'model.encoder.layers.0.self_attn': OBSLinearCache(name=None, weight=None, input=None, output=None, module=None), 'model.encoder.layers.0.self_attn.k_proj': OBSLinearCache(name=None, weight=None, input=None, output=N

In [None]:
# Create dummy input
dummy_input_features = torch.randn(1, 80, 3000)  

# Feed dummy input through the model
with torch.no_grad():
    outputs = model.generate(dummy_input_features, max_length=448)
print("Model forward pass successful!")
print(f"Generated output shape: {outputs.shape}")
print(caches.keys())

Model forward pass successful!
Generated output shape: torch.Size([1, 6])
