In [1]:
# Uncomment to enable IPython's autoreload, so edits to the spec_mamba package
# are picked up without restarting the kernel.
# %load_ext autoreload
# %autoreload 2

# Imports

In [2]:
import os

import torch
from torch import nn
from torchinfo import summary

from spec_mamba import *

In [3]:
# CUDA device index used by every cell below. It defaults to 2 and can be
# overridden with SPEC_MAMBA_DEVICE=<idx> on machines with fewer GPUs.
DEVICE = int(os.environ.get("SPEC_MAMBA_DEVICE", "2"))

# Seed torch's RNGs (all devices) so the random sample and model inits are reproducible.
SEED = 0
torch.manual_seed(SEED)

# Fail early with a readable message instead of an opaque CUDA device-ordinal error.
if not torch.cuda.is_available():
    raise RuntimeError("This notebook requires a CUDA-capable GPU.")
if not 0 <= DEVICE < torch.cuda.device_count():
    raise ValueError(
        f"DEVICE={DEVICE} is out of range: {torch.cuda.device_count()} CUDA device(s) visible; "
        "set SPEC_MAMBA_DEVICE to a valid index."
    )

torch.cuda.get_device_name(DEVICE)

'Quadro RTX 6000'

In [4]:
# Random input batch: 4 examples x 1 channel x (128, 65), matching the spec_size/channels used below.
sample = torch.randn(4, 1, 128, 65, device=DEVICE)

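# Audio Mamba

Build the AudioMamba encoder, inspect its layers, run a one-step training smoke test, then wrap the same configuration in the `AudioMambaCLF` classifier.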

In [5]:
# AudioMamba backbone configured to return the "emb" output; the summary lists
# per-layer output shapes and parameter counts for the random sample batch.
aum_model = AudioMamba(
    spec_size=(128, 65), patch_size=(16, 5), channels=1,
    embed_dim=192, depth=12,
    mask_ratio=0.5, cls_position="none",
    use_rms_norm=True, fused_add_norm=True, bi_mamba_type="v1",
    output_type="emb",
    ssm_cfg={"d_state": 24, "d_conv": 4, "expand": 3},
)
# nn.Module.eval() and .to() both act in place, so the model is updated without rebinding.
aum_model.eval()
aum_model.to(DEVICE)

summary(aum_model, device=f"cuda:{DEVICE}", depth=3, input_data=sample)

Layer (type:depth-idx)                   Output Shape              Param #
AudioMamba                               [4, 104, 80]              384
├─FlexiPatchEmbed: 1-1                   [4, 104, 192]             15,552
│    └─Identity: 2-1                     [4, 104, 192]             --
├─FlexiPosEmbed: 1-2                     [4, 104, 192]             19,968
├─Dropout: 1-3                           [4, 104, 192]             --
├─ModuleList: 1-4                        --                        --
│    └─MambaBlock: 2-2                   [4, 104, 192]             192
│    │    └─Mamba: 3-1                   [4, 104, 192]             404,928
│    └─MambaBlock: 2-3                   [4, 104, 192]             192
│    │    └─DropPath: 3-2                [4, 104, 192]             --
│    │    └─Mamba: 3-3                   [4, 104, 192]             404,928
│    └─MambaBlock: 2-4                   [4, 104, 192]             192
│    │    └─DropPath: 3-4                [4, 104, 192]         

In [6]:
# Training test: one AdamW step on the random batch, to check that the forward pass,
# reshape_as_spec and the backward pass run end to end. The weights are then checkpointed.
loss_fn = nn.MSELoss()
optimizer = torch.optim.AdamW(aum_model.parameters())

aum_model.train()
try:
    optimizer.zero_grad()
    logits, mask = aum_model(sample, output_type="emb")
    # NOTE(review): reshape_as_spec presumably folds the per-patch outputs back into
    # the input's (batch, channels, *spec_size) layout -- the assert below checks that.
    logits, mask = aum_model.reshape_as_spec(logits, mask)
    # MSELoss only warns (then broadcasts) on different-but-broadcastable shapes,
    # which would silently hide a reshape bug, so compare shapes explicitly.
    assert logits.shape == sample.shape, (
        f"reshape_as_spec returned {tuple(logits.shape)}, expected {tuple(sample.shape)}"
    )
    loss = loss_fn(logits, sample)
    assert torch.isfinite(loss), f"non-finite training loss: {loss.item()}"
    loss.backward()
    optimizer.step()
finally:
    # Restore eval mode even if the step fails, so later cells never see a train-mode model.
    aum_model.eval()

torch.save(aum_model.state_dict(), "aum_test.pt")
loss.item()

In [7]:
# Two-class AudioMambaCLF with the same backbone settings as aum_model, but
# output_type="mean" (presumably pooling the token embeddings -- TODO confirm).
aum_clf_args = AudioMambaCLFArgs(
    num_classes=2,
    spec_size=(128, 65), patch_size=(16, 5), channels=1,
    embed_dim=192, depth=12,
    mask_ratio=0.5, cls_position="none",
    use_rms_norm=True, fused_add_norm=True, bi_mamba_type="v1",
    output_type="mean",
    ssm_cfg={"d_state": 24, "d_conv": 4, "expand": 3},
)
aum_clf_model = AudioMambaCLF(**aum_clf_args)
# eval() and to() modify the module in place.
aum_clf_model.eval()
aum_clf_model.to(DEVICE)

summary(aum_clf_model, device=f"cuda:{DEVICE}", depth=3, input_data=sample)

Layer (type:depth-idx)                   Output Shape              Param #
AudioMambaCLF                            [4, 2]                    --
├─AudioMamba: 1-1                        --                        384
│    └─FlexiPatchEmbed: 2-1              [4, 104, 192]             15,552
│    │    └─Identity: 3-1                [4, 104, 192]             --
│    └─FlexiPosEmbed: 2-2                [4, 104, 192]             19,968
│    └─Dropout: 2-3                      [4, 104, 192]             --
│    └─ModuleList: 2-4                   --                        --
│    │    └─MambaBlock: 3-2              [4, 104, 192]             405,120
│    │    └─MambaBlock: 3-3              [4, 104, 192]             405,120
│    │    └─MambaBlock: 3-4              [4, 104, 192]             405,120
│    │    └─MambaBlock: 3-5              [4, 104, 192]             405,120
│    │    └─MambaBlock: 3-6              [4, 104, 192]             405,120
│    │    └─MambaBlock: 3-7              [4, 104, 1

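# SSAST

For comparison, the same spectrogram geometry, embedding width and depth, built from `SSASTBlock` layers (configured with `num_heads` and `mlp_ratio`).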

In [8]:
# SSAST model sharing the AudioMamba spectrogram geometry, width and depth;
# num_heads/mlp_ratio configure its SSASTBlock layers.
ssast_model = SSAST(
    spec_size=(128, 65), patch_size=(16, 5), channels=1,
    embed_dim=192, depth=12,
    num_heads=3, mlp_ratio=4,
    mask_ratio=0.5, cls_position="none",
    use_rms_norm=True,
    output_type="emb",
)
# eval() and to() modify the module in place.
ssast_model.eval()
ssast_model.to(DEVICE)

summary(ssast_model, device=f"cuda:{DEVICE}", depth=2, input_data=sample)

Layer (type:depth-idx)                   Output Shape              Param #
SSAST                                    [4, 104, 80]              192
├─FlexiPatchEmbed: 1-1                   [4, 104, 192]             15,552
│    └─Identity: 2-1                     [4, 104, 192]             --
├─FlexiPosEmbed: 1-2                     [4, 104, 192]             19,968
├─Dropout: 1-3                           [4, 104, 192]             --
├─ModuleList: 1-4                        --                        --
│    └─SSASTBlock: 2-2                   [4, 104, 192]             444,480
│    └─SSASTBlock: 2-3                   [4, 104, 192]             444,480
│    └─SSASTBlock: 2-4                   [4, 104, 192]             444,480
│    └─SSASTBlock: 2-5                   [4, 104, 192]             444,480
│    └─SSASTBlock: 2-6                   [4, 104, 192]             444,480
│    └─SSASTBlock: 2-7                   [4, 104, 192]             444,480
│    └─SSASTBlock: 2-8                   [4, 1