In [1]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
from vap.modules.transformer import VapStereoTower
from vap.modules.encoder import EncoderCPC

ModuleNotFoundError: No module named 'x_transformers'

In [None]:
class EncoderCPC(nn.Module):
    """Minimal stand-in encoder: a single 1-D convolution over a mono waveform.

    Args:
        output_dims: Number of feature channels the convolution produces.
    """

    def __init__(self, output_dims=256):
        super().__init__()
        self.output_dims = output_dims
        # kernel_size=3 with padding=1 keeps the time axis the same length.
        self.conv = nn.Conv1d(
            in_channels=1,
            out_channels=output_dims,
            kernel_size=3,
            padding=1,
        )

    def forward(self, x):
        """Map a (batch, 1, samples) waveform to (batch, output_dims, samples)."""
        features = self.conv(x)
        return features

class VapStereoTower(nn.Module):
    """Simplified stand-in for the VAP stereo transformer tower.

    A stack of ``nn.TransformerEncoderLayer`` blocks applied to the first
    channel only. Inputs are batch-first: ``(batch, seq, dim)``.

    Args:
        dim: Embedding size (``d_model``) of every layer.
        num_heads: Number of attention heads per layer.
        num_layers: Number of stacked encoder layers.
    """

    def __init__(self, dim, num_heads, num_layers):
        super().__init__()
        # batch_first=True so the layers read inputs as (batch, seq, dim),
        # which is the layout the (batch, queries, keys) attention maps
        # returned by forward() assume. The default (seq, batch, dim) layout
        # would make attention mix information across the batch axis.
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=dim, nhead=num_heads, batch_first=True)
            for _ in range(num_layers)
        ])

    def forward(self, x1, x2, return_attn=False):
        """Run the layer stack on ``x1``.

        Args:
            x1: Tensor of shape ``(batch, seq, dim)``.
            x2: Second-channel tensor. Accepted for interface compatibility
                but not used by this simplified tower.
            return_attn: When True, also return the self-attention weights.

        Returns:
            The transformed ``x1``; or, when ``return_attn`` is True, a tuple
            ``(x1, attn_weights)`` where ``attn_weights`` holds one
            ``(batch, seq, seq)`` tensor per layer, averaged over heads.
            Call ``.eval()`` first for deterministic maps: in training mode
            the attention dropout is applied to the returned weights too.
        """
        attn_weights = []
        for layer in self.layers:
            if return_attn:
                # The layers are post-norm (norm_first=False), so each layer's
                # self-attention operates directly on the layer input. Querying
                # the same attention module with that input therefore yields
                # the weights the layer uses, instead of random placeholders.
                _, attn = layer.self_attn(x1, x1, x1, need_weights=True)
                attn_weights.append(attn)
            x1 = layer(x1)
        if return_attn:
            return x1, attn_weights
        return x1

class VAP(nn.Module):
    """Encode each channel of a stereo waveform and return attention maps.

    Args:
        encoder: Module applied to each channel separately. It receives a
            ``(batch, 1, samples)`` tensor and is expected to return
            channel-first features ``(batch, dim, frames)``, as a ``Conv1d``
            based encoder does.
        transformer: Module called as ``transformer(x1, x2, return_attn=True)``
            and expected to return ``(output, attn_weights)``.
    """

    def __init__(self, encoder, transformer):
        super().__init__()
        self.encoder = encoder
        self.transformer = transformer

    def forward(self, x):
        """Return the transformer's attention weights for a stereo input.

        Args:
            x: Waveform tensor of shape ``(batch, channels, samples)`` with at
                least two channels; only the first two are used.

        Returns:
            Whatever ``attn_weights`` object the transformer returns.

        Raises:
            ValueError: If ``x`` is not 3-D or has fewer than two channels.
        """
        if x.dim() != 3 or x.size(1) < 2:
            raise ValueError(
                "expected input of shape (batch, channels>=2, samples), "
                f"got {tuple(x.shape)}"
            )

        # Slice with 0:1 / 1:2 (not an integer index) to keep the channel axis
        # the encoder expects.
        x1 = self.encoder(x[:, 0:1])
        x2 = self.encoder(x[:, 1:2])

        # The encoder emits (batch, dim, frames); transformer layers apply
        # d_model to the last axis, so move the feature axis to the end.
        x1 = x1.transpose(1, 2)
        x2 = x2.transpose(1, 2)

        _, attn_weights = self.transformer(x1, x2, return_attn=True)
        return attn_weights

In [None]:
# Configuration for the demo run.
SEED = 0
EMBED_DIM = 256
# The encoder convolution (kernel_size=3, padding=1) emits one frame per input
# sample, so the attention map has NUM_SAMPLES x NUM_SAMPLES entries per layer.
# Keep the clip short: 16000 samples would mean 2.56e8 entries per map.
NUM_SAMPLES = 16

# Seed before building the model so weights and dummy input are reproducible.
torch.manual_seed(SEED)

# Instantiate the model components
encoder = EncoderCPC(output_dims=EMBED_DIM)
transformer = VapStereoTower(dim=EMBED_DIM, num_heads=8, num_layers=1)

# Instantiate the VAP model; eval() disables dropout for a deterministic pass
model = VAP(encoder, transformer)
model.eval()

# Generate dummy stereo audio signal
dummy_audio = torch.rand(1, 2, NUM_SAMPLES)  # Batch size, channels, samples

# Forward pass to get attention weights (inference only, no autograd graph)
with torch.no_grad():
    attention_weights = model(dummy_audio)

In [None]:
# attention_weights is assumed to be a list with one tensor per layer, each
# shaped [batch_size, num_queries, num_keys]. Indexing [0][0] therefore selects
# the first layer and the first batch item (not an individual head).
attn_map = attention_weights[0][0].detach().cpu().numpy()

# Per-cell annotations are only legible (and cheap to draw) on small maps,
# since seaborn renders one text label per cell.
MAX_ANNOTATED_SIZE = 16
show_values = max(attn_map.shape) <= MAX_ANNOTATED_SIZE

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(attn_map, cmap='viridis', annot=show_values, fmt='.2f', ax=ax)
ax.set_title('Attention Map')
ax.set_xlabel('Keys')
ax.set_ylabel('Queries')
plt.show()