In [1]:
import torch
from torch import nn

from models.mn.model import mobilenet_v3
from helpers.utils import NAME_TO_WIDTH  # exists in EfficientAT repo

# Checkpoint family of the weights that are loaded from disk further down.
# It is only used here to look up the matching network width.
pretrained_name = "mn10_as"       # must match the checkpoint family
width = NAME_TO_WIDTH(pretrained_name)  # e.g., 4.0 for mn40_as
print(width)

num_classes = 527
# Reset to None before building so the model constructor does not try to fetch
# weights itself; the checkpoint is loaded from a local file instead.
pretrained_name = None
# Use the width derived from the checkpoint family rather than a hard-coded 1.0,
# so changing the family above keeps architecture and checkpoint in sync.
width_mult = width
reduced_tail = False
dilated = False
strides = (2, 2, 2, 2)
head_type = "mlp"
multihead_attention_heads = 4
input_dim_f = 128
input_dim_t = 300
se_dims = 'c'
se_agg = "max"
se_r = 4

input_dims = (input_dim_f, input_dim_t)

# Translate the squeeze-excitation dimension letters into tensor axis indices.
dim_map = {'c': 1, 'f': 2, 't': 3}
# Explicit parentheses: either the literal 'none', or at most three known letters.
assert se_dims == 'none' or (len(se_dims) <= 3 and all(s in dim_map for s in se_dims))
if se_dims == 'none':
    se_dims = None
else:
    se_dims = [dim_map[s] for s in se_dims]

se_conf = dict(se_dims=se_dims, se_agg=se_agg, se_r=se_r)

model = mobilenet_v3(pretrained_name=pretrained_name, num_classes=num_classes,
                 width_mult=width_mult, reduced_tail=reduced_tail, dilated=dilated, strides=strides,
                 head_type=head_type, multihead_attention_heads=multihead_attention_heads,
                 input_dims=input_dims, se_conf=se_conf
                 )


# 2) Load the checkpoint you put into resources/
ckpt_path = "./resources/mn10_as_mels_64_mAP_461.pt"
# NOTE(security): torch.load unpickles the file, so only load checkpoints from a
# trusted source (newer PyTorch versions offer weights_only=True for plain state dicts).
sd = torch.load(ckpt_path, map_location="cpu")

# 3) Load every tensor that fits the current model.
# load_state_dict(strict=False) only tolerates missing/unexpected keys; a tensor whose
# shape differs (e.g. the classifier when num_classes=50) still raises RuntimeError.
# Such entries are therefore filtered out explicitly and reported instead of retrying
# with strict=False.
model_sd = model.state_dict()
compatible_sd = {
    key: tensor
    for key, tensor in sd.items()
    if key in model_sd and tensor.shape == model_sd[key].shape
}
skipped_keys = sorted(set(sd) - set(compatible_sd))
if skipped_keys:
    print(f"Skipped {len(skipped_keys)} checkpoint key(s) that are absent from the model "
          f"or differ in shape: {skipped_keys}")

load_result = model.load_state_dict(compatible_sd, strict=False)
if load_result.missing_keys:
    # These parameters keep their freshly initialised values.
    print(f"Model key(s) not initialised from the checkpoint: {load_result.missing_keys}")

1.0


In [2]:
# print(model)
# print(model.classifier)
# print(model.classifier[5])
# print(model.features[16])

In [3]:
# Swap the final classification layer for a freshly initialised 50-class head.
# The input width is read from the layer being replaced instead of hard-coding 1280,
# so the cell keeps working if the model is rebuilt with a different width.
head_in_features = model.classifier[5].in_features
model.classifier[5] = nn.Linear(head_in_features, 50)
print(model)

MN(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): Hardswish()
    )
    (1): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (2): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)

In [4]:
import torch
import torch.onnx

# Export the current `model` to ONNX with two named outputs.

# ---- 1) Set your model to eval mode ----
model.eval()

# ---- 2) Create a dummy mel-spectrogram input ----
# Shape is (batch, channel, input_dims[0], input_dims[1]); `input_dims` comes from the
# cell that built the model, so re-run that cell first if the dims were changed.
dummy_input = torch.randn(2, 1, input_dims[0], input_dims[1])  # batch of 2, 1 channel, e.g. (2, 1, 128, 300)

# ---- 3) ONNX save path ----
onnx_path = "mobilenetv3_audio.onnx"

# ---- 4) Export ----
torch.onnx.export(
    model,
    dummy_input,
    onnx_path,
    export_params=True,            # store weights inside the ONNX
    opset_version=17,              # NOTE(review): confirm the installed torch/onnxruntime support opset 17
    do_constant_folding=True,      # optimization
    input_names=["mel"],
    output_names=["logits", "embedding"],  # the model returns two tensors (see the forward-pass cell)
    # Only axis 0 is declared dynamic; the mel and frame axes keep the sizes of
    # `dummy_input`, so inference inputs must use the same two trailing dimensions.
    dynamic_axes={
        "mel": {0: "batch_size"},  # allow dynamic batch size
        "logits": {0: "batch_size"},
        "embedding": {0: "batch_size"}
    }
)

print(f"Export complete: {onnx_path}")


Export complete: mobilenetv3_audio.onnx


In [5]:
# Smoke test: push a random batch through the PyTorch model and print output shapes.
model.eval()

# Random input with 66 mel bins, which differs from the 128 used when the model was
# built; the recorded output below shows the forward pass still succeeds.
x = torch.randn(2, 1, 66, 300)   # (B, C, mel, frames)
with torch.no_grad():
    logits, feats = model(x)          # EfficientAT MN returns (logits, pooled_features)
print(logits.shape)  # recorded run: [2, 50] because the head was replaced by a 50-class layer ([2, 527] with the original head)
print(feats.shape)   # recorded run: [2, 960]

torch.Size([2, 50])
torch.Size([2, 960])


In [7]:
def inspect_state_dict_loading(model, state_dict, prefix=""):
    """
    Print a report comparing a model's parameters with a checkpoint state_dict.

    Nothing is loaded or modified; the function only inspects key names and shapes.

    Args:
        model: object exposing ``state_dict()`` that returns a mapping from key to
            an entry with a ``.shape`` attribute (e.g. a ``torch.nn.Module``).
        state_dict: checkpoint mapping from key to an entry with ``.shape``.
        prefix: label printed in front of the report title.

    The report lists:
      - matched keys (present in both, identical shape)
      - missing keys (in the model, not in the checkpoint)
      - unexpected keys (in the checkpoint, not in the model)
      - shape mismatches (present in both, different shape)

    Returns:
        None. The report is written to stdout.
    """
    # Take one snapshot of the model state: each state_dict() call rebuilds the whole
    # mapping, so calling it inside the loop below would be quadratic in the key count.
    model_state = model.state_dict()

    model_keys = set(model_state.keys())
    ckpt_keys = set(state_dict.keys())

    missing_keys = sorted(model_keys - ckpt_keys)
    unexpected_keys = sorted(ckpt_keys - model_keys)
    matched_keys = []
    shape_mismatches = []

    # Split the shared keys into exact matches and shape mismatches.
    for k in sorted(model_keys & ckpt_keys):
        model_shape = tuple(model_state[k].shape)
        ckpt_shape = tuple(state_dict[k].shape)
        if model_shape == ckpt_shape:
            matched_keys.append(k)
        else:
            shape_mismatches.append((k, model_shape, ckpt_shape))

    # Pretty printing
    print("\n" + "="*60)
    print(f"{prefix} STATE DICT LOAD REPORT")
    print("="*60)

    print(f"\n✔ Matched keys ({len(matched_keys)}):")
    for k in matched_keys:
        print(f"  {k}")

    print(f"\n✘ Missing keys ({len(missing_keys)}):")
    for k in missing_keys:
        print(f"  {k}")

    print(f"\n⚠ Unexpected keys ({len(unexpected_keys)}):")
    for k in unexpected_keys:
        print(f"  {k}")

    print(f"\n❗ Shape mismatches ({len(shape_mismatches)}):")
    for k, ms, cs in shape_mismatches:
        print(f"  {k}: model={ms}, checkpoint={cs}")

    print("\nDone.\n" + "="*60 + "\n")


import torch
from models.mn.model import mobilenet_v3
from helpers.utils import NAME_TO_WIDTH  # exists in EfficientAT repo

# Rebuild the architecture (this time with 1000 time frames) and report how well the
# local checkpoint lines up with it, without actually loading any weights.

# Checkpoint family; only used here to look up the matching network width.
pretrained_name = "mn10_as"       # must match the checkpoint family
width = NAME_TO_WIDTH(pretrained_name)  # e.g., 4.0 for mn40_as
print(width)

num_classes = 527
# Reset to None before building so the model constructor does not try to fetch
# weights itself; the checkpoint is read from a local file below.
pretrained_name = None
# Use the width derived from the checkpoint family rather than a hard-coded 1.0,
# so changing the family above keeps architecture and checkpoint in sync.
width_mult = width
reduced_tail = False
dilated = False
strides = (2, 2, 2, 2)
head_type = "mlp"
multihead_attention_heads = 4
input_dim_f = 128
input_dim_t = 1000
se_dims = 'c'
se_agg = "max"
se_r = 4

input_dims = (input_dim_f, input_dim_t)

# Translate the squeeze-excitation dimension letters into tensor axis indices.
dim_map = {'c': 1, 'f': 2, 't': 3}
# Explicit parentheses: either the literal 'none', or at most three known letters.
assert se_dims == 'none' or (len(se_dims) <= 3 and all(s in dim_map for s in se_dims))
if se_dims == 'none':
    se_dims = None
else:
    se_dims = [dim_map[s] for s in se_dims]

se_conf = dict(se_dims=se_dims, se_agg=se_agg, se_r=se_r)

model = mobilenet_v3(pretrained_name=pretrained_name, num_classes=num_classes,
                 width_mult=width_mult, reduced_tail=reduced_tail, dilated=dilated, strides=strides,
                 head_type=head_type, multihead_attention_heads=multihead_attention_heads,
                 input_dims=input_dims, se_conf=se_conf
                 )
# model.classifier[5] = nn.Linear(1280, 50)

# `ckpt_path` is defined in the first checkpoint-loading cell; run that cell first.
# NOTE(security): torch.load unpickles the file, so only load checkpoints from a
# trusted source (newer PyTorch versions offer weights_only=True for plain state dicts).
state_dict = torch.load(ckpt_path, map_location="cpu")

inspect_state_dict_loading(model, state_dict, prefix="EfficientATMN")
# model.load_state_dict(state_dict, strict=False)


1.0

EfficientATMN STATE DICT LOAD REPORT

✔ Matched keys (310):
  classifier.2.bias
  classifier.2.weight
  features.0.0.weight
  features.0.1.bias
  features.0.1.num_batches_tracked
  features.0.1.running_mean
  features.0.1.running_var
  features.0.1.weight
  features.1.block.0.0.weight
  features.1.block.0.1.bias
  features.1.block.0.1.num_batches_tracked
  features.1.block.0.1.running_mean
  features.1.block.0.1.running_var
  features.1.block.0.1.weight
  features.1.block.1.0.weight
  features.1.block.1.1.bias
  features.1.block.1.1.num_batches_tracked
  features.1.block.1.1.running_mean
  features.1.block.1.1.running_var
  features.1.block.1.1.weight
  features.10.block.0.0.weight
  features.10.block.0.1.bias
  features.10.block.0.1.num_batches_tracked
  features.10.block.0.1.running_mean
  features.10.block.0.1.running_var
  features.10.block.0.1.weight
  features.10.block.1.0.weight
  features.10.block.1.1.bias
  features.10.block.1.1.num_batches_tracked
  features.10.block.1.1

In [None]:
# print(model)

In [None]:
import torch
from models.mn.model import get_model
from helpers.utils import NAME_TO_WIDTH  # exists in EfficientAT repo

# Build the network through `get_model`, taking the width from the checkpoint family
# so that the architecture matches the weights loaded from disk afterwards.
pretrained_name = "mn10_as"       # must match the checkpoint family
width = NAME_TO_WIDTH(pretrained_name)  # e.g., 4.0 for mn40_as
print(width)

# 1) Build the architecture with your desired head & dims
model = get_model(
    num_classes=527,     # or 50 if you want to initialize with a fresh classifier
    pretrained_name=None, # <-- don't trigger URL loader
    width_mult=width,     # width looked up above, not hard-coded
    head_type="mlp",
    input_dim_f=128,
    input_dim_t=300               # ~3s with 10ms hop ≈ 300 frames
)

# 2) Load the checkpoint you put into resources/
ckpt_path = "./resources/mn10_as_mels_64_mAP_461.pt"
# NOTE(security): torch.load unpickles the file, so only load checkpoints from a
# trusted source (newer PyTorch versions offer weights_only=True for plain state dicts).
sd = torch.load(ckpt_path, map_location="cpu")

# 3) Load every tensor that fits the current model.
# load_state_dict(strict=False) only tolerates missing/unexpected keys; a tensor whose
# shape differs (e.g. the classifier when num_classes=50) still raises RuntimeError.
# Such entries are therefore filtered out explicitly and reported instead of retrying
# with strict=False.
model_sd = model.state_dict()
compatible_sd = {
    key: tensor
    for key, tensor in sd.items()
    if key in model_sd and tensor.shape == model_sd[key].shape
}
skipped_keys = sorted(set(sd) - set(compatible_sd))
if skipped_keys:
    print(f"Skipped {len(skipped_keys)} checkpoint key(s) that are absent from the model "
          f"or differ in shape: {skipped_keys}")

load_result = model.load_state_dict(compatible_sd, strict=False)
if load_result.missing_keys:
    # These parameters keep their freshly initialised values.
    print(f"Model key(s) not initialised from the checkpoint: {load_result.missing_keys}")

In [None]:
# Smoke test for the model built with `get_model`: run a random batch and print shapes.
model.eval()

# 180 mel bins here versus input_dim_f=128 at build time.
# NOTE(review): confirm the model is meant to accept a different mel dimension.
x = torch.randn(2, 1, 180, 300)   # (B, C, mel, frames)
with torch.no_grad():
    logits, feats = model(x)          # EfficientAT MN returns (logits, pooled_features)
print(logits.shape)  # expect: [2, 527] for AudioSet pretrain head
print(feats.shape)   # expect: [2, <embed_dim>]

In [None]:
import torch
from models.mn.model import get_model
from helpers.utils import NAME_TO_WIDTH  # exists in EfficientAT repo

# Alternative route: let `get_model` obtain the pretrained weights itself by passing
# the checkpoint family name, and ask for a 50-class head.
pretrained_name = "mn10_as"       # must match the checkpoint family
width = NAME_TO_WIDTH(pretrained_name)  # e.g., 4.0 for mn40_as
print(width)

# Build with the **correct width** and your head type
# Tip: set num_classes=50 for ESC-50.
# NOTE(review): presumably the loader drops the pretrained classifier when the class
# count differs; confirm in models/mn/model.py.
model1 = get_model(
    num_classes=50,               # ESC-50
    pretrained_name=pretrained_name,
    width_mult=width,
    head_type="mlp",              # or "fully_convolutional" if that's your choice
    input_dim_f=128,
    input_dim_t=300               # ~3s with 10ms hop ≈ 300 frames
)

In [None]:

import numpy as np
import onnxruntime as ort

# Sanity-check the exported ONNX file with ONNX Runtime on CPU.
onnx_path = "mobilenetv3_audio.onnx"
sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

# Inspect model I/O
print("Inputs:")
for i in sess.get_inputs():
    print(" -", i.name, i.shape, i.type)

print("Outputs:")
for o in sess.get_outputs():
    print(" -", o.name, o.shape, o.type)

# Prepare dummy input (N, C, F, T).
# The export only marks the batch axis as dynamic, so the mel/frame sizes are frozen
# in the graph. Read them from the session instead of the notebook-global `input_dims`,
# which another cell may have changed since the export. Symbolic (non-int) dimensions
# fall back to `input_dims`.
mel_input = sess.get_inputs()[0]
graph_f, graph_t = mel_input.shape[2], mel_input.shape[3]
F = graph_f if isinstance(graph_f, int) else input_dims[0]
T = graph_t if isinstance(graph_t, int) else input_dims[1]
mel = np.random.randn(1, 1, F, T).astype(np.float32)

# Run inference
outputs = sess.run(None, {mel_input.name: mel})

logits = outputs[0]          # shape: (N, num_classes)
embedding = outputs[1]       # shape: (N, last_channel)
print("logits:", logits.shape, "embedding:", embedding.shape)

# Optional: Softmax to get probabilities (max-subtraction keeps exp() numerically stable)
probs = np.exp(logits - logits.max(axis=1, keepdims=True))
probs = probs / probs.sum(axis=1, keepdims=True)
topk_idx = probs[0].argsort()[::-1][:5]
print("Top-5 class ids:", topk_idx)
print("Top-5 probs:", probs[0][topk_idx])
print("Top-5 probs:", probs[0][topk_idx])



In [8]:
# Build the model through the project helper and print its module tree.
# NOTE(review): `build_model` is a project-local module not shown here; presumably it
# wraps the construction and checkpoint loading done by hand above. Confirm.
from build_model import build_model_from_efficientat
m = build_model_from_efficientat()
print(m)

MN(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): Hardswish()
    )
    (1): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (2): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)