#### Aim: Trying to visualize the attention in a pretrained VAP model to see whether attention sinks occur here after some sample audios.

In [2]:
import torch
import torchaudio
import matplotlib.pyplot as plt
from vap.modules.lightning_module import VAPModule
from vap.data.datamodule import VAPDataModule
import os
import numpy as np

In [3]:
# Load the pre-trained model
checkpoint_path = "/home/serhan/Desktop/VoiceActivityProjection/example/checkpoints/checkpoint.ckpt"
model = VAPModule.load_from_checkpoint(checkpoint_path)
model.eval()

/home/serhan/miniconda3/envs/vap/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:208: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.


VAPModule(
  (model): VAP(
    (encoder): EncoderCPC(
      (encoder): CPCModel(
        (gEncoder): CPCEncoder(
          (conv0): Conv1d(1, 256, kernel_size=(10,), stride=(5,), padding=(3,))
          (batchNorm0): ChannelNorm()
          (conv1): Conv1d(256, 256, kernel_size=(8,), stride=(4,), padding=(2,))
          (batchNorm1): ChannelNorm()
          (conv2): Conv1d(256, 256, kernel_size=(4,), stride=(2,), padding=(1,))
          (batchNorm2): ChannelNorm()
          (conv3): Conv1d(256, 256, kernel_size=(4,), stride=(2,), padding=(1,))
          (batchNorm3): ChannelNorm()
          (conv4): Conv1d(256, 256, kernel_size=(4,), stride=(2,), padding=(1,))
          (batchNorm4): ChannelNorm()
        )
        (gAR): CPCAR(
          (baseNet): LSTM(256, 256, batch_first=True)
        )
      )
      (downsample): Sequential(
        (0): Rearrange('b t d -> b d t')
        (1): CConv1d(
          256, 256, kernel_size=(5,), stride=(2,)
          (pad): ConstantPad1d(padding=(4, 0

In [4]:
# Prepare a batch of audio samples
audio_dir = "/home/erik/projects/data/Fisher/fisher_eng_tr_sp_d1/audio/000"
audio_files = [f"fe_03_{i:05d}.wav" for i in range(1, 100)]
batch_size = 4
sample_rate = 8000  # Assuming 8kHz sample rate, adjust if different

def load_and_process_audio(file_path):
    waveform, sr = torchaudio.load(file_path)
    if sr != sample_rate:
        waveform = torchaudio.transforms.Resample(sr, sample_rate)(waveform)
    # Assuming stereo audio, if mono, adjust accordingly
    if waveform.shape[0] == 1:
        waveform = waveform.repeat(2, 1)
    # Truncate or pad to 10 seconds
    target_length = 10 * sample_rate
    if waveform.shape[1] < target_length:
        waveform = torch.nn.functional.pad(waveform, (0, target_length - waveform.shape[1]))
    else:
        waveform = waveform[:, :target_length]
    return waveform

In [5]:
# Load a batch of audio samples
batch = []
for i in range(batch_size):
    file_path = os.path.join(audio_dir, audio_files[i])
    waveform = load_and_process_audio(file_path)
    batch.append(waveform)

batch = torch.stack(batch)

In [7]:
# Process the batch through the model
with torch.no_grad():
    outputs = model(batch)

In [12]:
# Function to extract attention weights
def get_attention_weights(module):
    attention_weights = []
    handles = []
    
    def hook(module, input, output):
        attention_weights.append(output[1].detach())
    
    for name, layer in module.named_modules():
        if isinstance(layer, torch.nn.MultiheadAttention):
            handle = layer.register_forward_hook(hook)
            handles.append(handle)
    
    return attention_weights, handles

In [18]:
# Extract attention weights
attention_weights, handles = get_attention_weights(model.model.transformer)

In [19]:
attention_weights

[]

In [14]:
# Run the model again to get attention weights
with torch.no_grad():
    _ = model(batch)

In [15]:
# Remove the hooks
for handle in handles:
    handle.remove()

In [16]:
# Visualize attention weights
num_layers = len(attention_weights)
fig, axes = plt.subplots(num_layers, 1, figsize=(10, 5 * num_layers))
if num_layers == 1:
    axes = [axes]

for layer, weights in enumerate(attention_weights):
    avg_weights = weights.mean(0).cpu().numpy()
    im = axes[layer].imshow(avg_weights, aspect='auto', cmap='viridis')
    axes[layer].set_title(f'Average Attention Weights - Layer {layer+1}')
    axes[layer].set_xlabel('Key Position')
    axes[layer].set_ylabel('Query Position')
    fig.colorbar(im, ax=axes[layer])

plt.tight_layout()
plt.savefig('attention_weights.png')
plt.close()

# Analyze attention to initial tokens
initial_token_attention = [weights[:, :, :4].mean(dim=(0, 1)).cpu().numpy() for weights in attention_weights]

fig, ax = plt.subplots(figsize=(10, 5))
for layer, attn in enumerate(initial_token_attention):
    ax.plot(range(4), attn, label=f'Layer {layer+1}')
ax.set_xlabel('Initial Token Position')
ax.set_ylabel('Average Attention Weight')
ax.set_title('Attention to Initial Tokens Across Layers')
ax.legend()
plt.savefig('initial_token_attention.png')
plt.close()

print("Analysis complete. Check 'attention_weights.png' and 'initial_token_attention.png' for visualizations.")

# Print statistics about attention to initial tokens
print("\nAverage attention to initial tokens:")
for layer, attn in enumerate(initial_token_attention):
    print(f"Layer {layer+1}: {attn.mean():.4f}")

print("\nAttention sink analysis:")
for layer, attn in enumerate(initial_token_attention):
    if attn[0] > attn[1:].mean() * 1.5:  # Arbitrary threshold, adjust as needed
        print(f"Layer {layer+1} shows strong attention sink behavior")
    else:
        print(f"Layer {layer+1} does not show strong attention sink behavior")

ValueError: Number of rows must be a positive integer, not 0

<Figure size 1000x0 with 0 Axes>

#### 2nd Try

In [32]:
import torch
import matplotlib.pyplot as plt
from vap.modules.lightning_module import VAPModule
from vap.modules.modules import MultiHeadAttentionAlibi
import math

In [33]:
# Load the pre-trained model
checkpoint_path = "/home/serhan/Desktop/VoiceActivityProjection/example/checkpoints/checkpoint.ckpt"
model = VAPModule.load_from_checkpoint(checkpoint_path)
model.eval()

VAPModule(
  (model): VAP(
    (encoder): EncoderCPC(
      (encoder): CPCModel(
        (gEncoder): CPCEncoder(
          (conv0): Conv1d(1, 256, kernel_size=(10,), stride=(5,), padding=(3,))
          (batchNorm0): ChannelNorm()
          (conv1): Conv1d(256, 256, kernel_size=(8,), stride=(4,), padding=(2,))
          (batchNorm1): ChannelNorm()
          (conv2): Conv1d(256, 256, kernel_size=(4,), stride=(2,), padding=(1,))
          (batchNorm2): ChannelNorm()
          (conv3): Conv1d(256, 256, kernel_size=(4,), stride=(2,), padding=(1,))
          (batchNorm3): ChannelNorm()
          (conv4): Conv1d(256, 256, kernel_size=(4,), stride=(2,), padding=(1,))
          (batchNorm4): ChannelNorm()
        )
        (gAR): CPCAR(
          (baseNet): LSTM(256, 256, batch_first=True)
        )
      )
      (downsample): Sequential(
        (0): Rearrange('b t d -> b d t')
        (1): CConv1d(
          256, 256, kernel_size=(5,), stride=(2,)
          (pad): ConstantPad1d(padding=(4, 0

In [40]:
# Inspect model structure
def print_model_structure(model, indent=0):
    for name, module in model.named_children():
        print('  ' * indent + f"{name}: {type(module)}")
        if list(module.children()):
            print_model_structure(module, indent + 1)

print("Model structure:")
print_model_structure(model.model)

Model structure:
encoder: <class 'vap.modules.encoder.EncoderCPC'>
  encoder: <class 'vap.modules.encoder_components.CPCModel'>
    gEncoder: <class 'vap.modules.encoder_components.CPCEncoder'>
      conv0: <class 'torch.nn.modules.conv.Conv1d'>
      batchNorm0: <class 'vap.modules.encoder_components.ChannelNorm'>
      conv1: <class 'torch.nn.modules.conv.Conv1d'>
      batchNorm1: <class 'vap.modules.encoder_components.ChannelNorm'>
      conv2: <class 'torch.nn.modules.conv.Conv1d'>
      batchNorm2: <class 'vap.modules.encoder_components.ChannelNorm'>
      conv3: <class 'torch.nn.modules.conv.Conv1d'>
      batchNorm3: <class 'vap.modules.encoder_components.ChannelNorm'>
      conv4: <class 'torch.nn.modules.conv.Conv1d'>
      batchNorm4: <class 'vap.modules.encoder_components.ChannelNorm'>
    gAR: <class 'vap.modules.encoder_components.CPCAR'>
      baseNet: <class 'torch.nn.modules.rnn.LSTM'>
  downsample: <class 'torch.nn.modules.container.Sequential'>
    0: <class 'einops.

In [51]:
def get_attention_weights(module):
    attention_weights = []
    handles = []
    
    def hook(module, input, output):
        # Try to access attention weights directly from the module
        if hasattr(module, 'last_attn_weights'):
            attention_weights.append(module.last_attn_weights.detach())
    
    for name, layer in module.named_modules():
        if isinstance(layer, MultiHeadAttentionAlibi):
            handle = layer.register_forward_hook(hook)
            handles.append(handle)
    
    return attention_weights, handles

In [52]:
# Extract attention weights
attention_weights, handles = get_attention_weights(model.model.transformer)

In [53]:
# Prepare the correct input shape
dummy_input = torch.randn(1, 2, 16000)  # Batch size 1, 2 channels, 16000 samples (1 second at 16kHz)
print("Input shape:", dummy_input.shape)

# Process the dummy input through the model
with torch.no_grad():
    try:
        outputs = model(dummy_input)
        print("Forward pass successful")
    except Exception as e:
        print(f"Error during model forward pass: {str(e)}")

# Remove the hooks
for handle in handles:
    handle.remove()

print(f"Number of attention weights captured: {len(attention_weights)}")

if attention_weights:
    print("Attention weights captured successfully")
    # Here you can add visualization code if needed
else:
    print("No attention weights were captured. The model might not store attention weights directly.")

# Print model's structure
print("\nModel structure:")
def print_model_structure(model, depth=0):
    for name, module in model.named_children():
        print("  " * depth + f"{name}: {type(module)}")
        if isinstance(module, MultiHeadAttentionAlibi):
            print("  " * (depth+1) + f"Attention layer found: {name}")
        print_model_structure(module, depth + 1)

print_model_structure(model.model)

# If no attention weights were captured, let's print the output of the model
if not attention_weights:
    print("\nModel output:")
    for key, value in outputs.items():
        print(f"{key}: {value.shape}")

Input shape: torch.Size([1, 2, 16000])
Error during model forward pass: not enough values to unpack (expected 3, got 0)
Number of attention weights captured: 0
No attention weights were captured. The model might not store attention weights directly.

Model structure:
encoder: <class 'vap.modules.encoder.EncoderCPC'>
  encoder: <class 'vap.modules.encoder_components.CPCModel'>
    gEncoder: <class 'vap.modules.encoder_components.CPCEncoder'>
      conv0: <class 'torch.nn.modules.conv.Conv1d'>
      batchNorm0: <class 'vap.modules.encoder_components.ChannelNorm'>
      conv1: <class 'torch.nn.modules.conv.Conv1d'>
      batchNorm1: <class 'vap.modules.encoder_components.ChannelNorm'>
      conv2: <class 'torch.nn.modules.conv.Conv1d'>
      batchNorm2: <class 'vap.modules.encoder_components.ChannelNorm'>
      conv3: <class 'torch.nn.modules.conv.Conv1d'>
      batchNorm3: <class 'vap.modules.encoder_components.ChannelNorm'>
      conv4: <class 'torch.nn.modules.conv.Conv1d'>
      batch

In [50]:
if attention_weights:
    # Visualize attention weights
    num_layers = len(attention_weights)
    fig, axes = plt.subplots(num_layers, 1, figsize=(10, 5 * num_layers))
    if num_layers == 1:
        axes = [axes]

    for layer, weights in enumerate(attention_weights):
        avg_weights = weights.mean(0).cpu().numpy()
        im = axes[layer].imshow(avg_weights, aspect='auto', cmap='viridis')
        axes[layer].set_title(f'Average Attention Weights - Layer {layer+1}')
        axes[layer].set_xlabel('Key Position')
        axes[layer].set_ylabel('Query Position')
        fig.colorbar(im, ax=axes[layer])

    plt.tight_layout()
    plt.savefig('attention_weights.png')
    plt.close()

    print("Visualization complete. Check 'attention_weights.png' for the result.")
else:
    print("No attention weights were captured. The model might not use standard attention layers.")

No attention weights were captured. The model might not use standard attention layers.
