In [None]:
# Optional: uncomment to enable IPython's autoreload extension so that edits to the
# `spec_mamba` package are picked up without restarting the kernel.
# %load_ext autoreload
# %autoreload 2

# Imports

In [None]:
from torchinfo import summary
from spec_mamba import *

In [None]:
# Configuration cell.
# DEVICE is a CUDA device *index* (an int): it is passed directly to torch
# (`device=DEVICE`, `.to(DEVICE)`) and formatted as f"cuda:{DEVICE}" for torchinfo.
# Change it to match the GPUs available on your machine.
DEVICE = 2

# Seed torch's RNG (all devices) so the random `sample` below, and presumably the
# models' weight initialisation, are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Show which GPU was selected (raises if the index does not exist).
torch.cuda.get_device_name(DEVICE)

In [None]:
# Random dummy batch: 4 examples x 1 channel x (128, 65), i.e. (batch, channels, *spec_size),
# matching `channels=1` / `spec_size=(128, 65)` in the model configs below.
sample = torch.randn(4, 1, 128, 65, device=DEVICE)

# Audio Mamba

In [None]:
# AudioMamba in embedding-output mode, moved to the selected GPU and put in eval
# mode, followed by a torchinfo summary on the dummy batch.
aum_model = AudioMamba(
    # input / patching geometry
    spec_size=(128, 65),
    patch_size=(16, 5),
    channels=1,
    # model size and masking
    embed_dim=192,
    depth=12,
    mask_ratio=0.5,
    cls_position="none",
    # normalisation and Mamba block options
    use_rms_norm=True,
    fused_add_norm=True,
    bi_mamba_type="v1",
    ssm_cfg={"d_state": 24, "d_conv": 4, "expand": 3},
    output_type="emb",
).to(DEVICE).eval()

summary(aum_model, input_data=sample, depth=3, device=f"cuda:{DEVICE}")

In [None]:
# Training smoke test: one AdamW step of an MSE reconstruction loss against the
# random `sample`, then save the resulting weights to "aum_test.pt".
# NOTE: this mutates `aum_model` in place; re-running the cell takes another step.
loss_fn = nn.MSELoss()
optimizer = torch.optim.AdamW(aum_model.parameters())

aum_model.train()
try:
    optimizer.zero_grad()
    # `logits` holds the model's "emb" output (used here as a reconstruction, not class logits).
    logits, mask = aum_model(sample, output_type="emb")
    # reshape_as_spec presumably maps the patch sequence back to the input's layout;
    # the assertion below enforces that it matches `sample`.
    logits, mask = aum_model.reshape_as_spec(logits, mask)
    # nn.MSELoss broadcasts mismatched shapes with only a warning, which would give a
    # meaningless loss, so fail loudly instead.
    assert logits.shape == sample.shape, (
        f"reconstruction shape {tuple(logits.shape)} != input shape {tuple(sample.shape)}"
    )
    # NOTE(review): `mask` is not used, so the loss covers every patch rather than only
    # the masked ones. Confirm this is intended for the smoke test.
    loss = loss_fn(logits, sample)
    loss.backward()
    optimizer.step()
finally:
    # Restore eval mode even if the step fails, so later cells see an eval-mode model.
    aum_model.eval()

torch.save(aum_model.state_dict(), "aum_test.pt")
loss.item()

In [None]:
# Classification variant: the same hyperparameters as `aum_model` above, plus
# `num_classes=2` and `output_type="mean"`. The args object is unpacked with `**`,
# so it behaves as a mapping of constructor keyword arguments.
aum_clf_args = AudioMambaCLFArgs(
    num_classes=2,
    spec_size=(128, 65),
    patch_size=(16, 5),
    channels=1,
    embed_dim=192,
    depth=12,
    mask_ratio=0.5,
    cls_position="none",
    use_rms_norm=True,
    fused_add_norm=True,
    bi_mamba_type="v1",
    output_type="mean",
    ssm_cfg={"d_state": 24, "d_conv": 4, "expand": 3},
)

aum_clf_model = AudioMambaCLF(**aum_clf_args)
aum_clf_model = aum_clf_model.to(DEVICE).eval()

summary(aum_clf_model, input_data=sample, depth=3, device=f"cuda:{DEVICE}")

# SSAST

In [None]:
# SSAST baseline with the same input geometry, width and depth as the AudioMamba
# models (3 attention heads, MLP ratio 4), in embedding-output mode. It is moved to
# the GPU, put in eval mode, and summarised (2 levels deep).
ssast_model = SSAST(
    # input / patching geometry
    spec_size=(128, 65),
    patch_size=(16, 5),
    channels=1,
    # transformer size
    embed_dim=192,
    depth=12,
    num_heads=3,
    mlp_ratio=4,
    # masking, normalisation and output
    mask_ratio=0.5,
    cls_position="none",
    use_rms_norm=True,
    output_type="emb",
).to(DEVICE).eval()

summary(ssast_model, input_data=sample, depth=2, device=f"cuda:{DEVICE}")