In [19]:
from torchvision.models.vision_transformer import VisionTransformer
import torch
import numpy as np
import torch.nn as nn

freq = 64
t = 400
chan = 256
b = 1

device = torch.device('mps')
x = torch.randn((b, chan, freq, t)).to(device)
print(x.shape)

torch.Size([1, 256, 64, 400])


In [21]:
class SpatioSpectralTransformer(nn.Module):
    def __init__(self, n_channels, n_freq_bins, n_mel_bands, d_model=256, nhead=8):
        super().__init__()
        
        # Project MEG channels and frequencies into embedding space
        self.spatial_embedding = nn.Conv2d(n_channels, d_model, kernel_size=1)
        
        # Transformer for cross-channel attention
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=d_model*4,
                batch_first=True
            ),
            num_layers=3
        )
        
        # Final projection to mel bands
        self.output_proj = nn.Conv2d(d_model, n_mel_bands, kernel_size=3, padding=1)
        
    def forward(self, x):
        # x shape: [B, C, freq, T]
        B, C, F, T = x.shape
        
        # Project channels
        x = self.spatial_embedding(x)  # [B, d_model, freq, T]
        print(f'x shape after spatial embedding: {x.shape}')
        
        # Reshape for transformer
        x = x.permute(0, 3, 2, 1)  # [B, T, freq, d_model]
        x = x.reshape(B*T, F, -1)   # [B*T, freq, d_model]
        print(f'x shape after reshaping: {x.shape}')
        
        # Apply transformer
        x = self.transformer(x)  # [B*T, freq, d_model]
        
        # Reshape back
        x = x.reshape(B, T, F, -1)  # [B, T, freq, d_model]
        x = x.permute(0, 3, 2, 1)   # [B, d_model, freq, T]
        
        # Project to mel bands
        x = self.output_proj(x)      # [B, mel_bands, freq, T]
        
        return x
    
model = SpatioSpectralTransformer(n_channels=chan, n_freq_bins=freq, n_mel_bands=128, d_model=256, nhead=8).to(device)
y = model(x)
print(y.shape)

x shape after spatial embedding: torch.Size([1, 256, 64, 400])
x shape after reshaping: torch.Size([400, 64, 256])
torch.Size([1, 128, 64, 400])


In [None]:
# model = VisionTransformer(
#     image_size=512,
#     patch_size=16,
#     num_layers=1,
#     num_heads=8,
#     hidden_dim=256,
#     mlp_dim=512,
# ).to(device)