In [1]:
import onnx



In [3]:
m = onnx.load("/home/yehoshua/projects/vad/ten-vad/src/onnx_model/ten-vad.onnx")

In [4]:
import onnx
from onnx import numpy_helper, shape_inference, helper


m = shape_inference.infer_shapes(m)          # fill in missing tensor shapes

print("IR version :", m.ir_version)
print("Opsets      :", [(o.domain or 'ai.onnx', o.version) for o in m.opset_import])

# Inputs / outputs ----------------------------------------------------------
for value in list(m.graph.input) + list(m.graph.output):
    t = value.type.tensor_type
    shape = [d.dim_value or '?' for d in t.shape.dim]
    print(f"{value.name:20}  {onnx.TensorProto.DataType.Name(t.elem_type):6}  {shape}")

# Nodes ---------------------------------------------------------------------
for i,node in enumerate(m.graph.node):
    print(f"{i:3d} {node.op_type:15}  {node.name or ''}")
    for a in node.attribute:
        # GRU / LSTM have many attributes - shorten output if needed
        if a.type == onnx.AttributeProto.INT or a.type == onnx.AttributeProto.FLOAT:
            print(f"     ├─ attr {a.name}: {helper.get_attribute_value(a)}")

IR version : 4
Opsets      : [('ai.onnx', 9), ('ai.onnx.ml', 2)]
input_1               FLOAT   ['?', 3, 41]
input_2               FLOAT   ['?', 64]
input_3               FLOAT   ['?', 64]
input_6               FLOAT   ['?', 64]
input_7               FLOAT   ['?', 64]
output_1              FLOAT   ['?', '?', 1]
output_2              FLOAT   [1, 64]
output_3              FLOAT   [1, 64]
output_6              FLOAT   [1, 64]
output_7              FLOAT   [1, 64]
  0 Unsqueeze        Unsqueeze__97
  1 Unsqueeze        Unsqueeze__95
  2 Unsqueeze        Unsqueeze__68
  3 Unsqueeze        Unsqueeze__66
  4 Unsqueeze        StatefulPartitionedCall/vad_model/ExpandDims
  5 Reshape          StatefulPartitionedCall/vad_model/separable_conv2d/separable_conv2d/depthwise__111
  6 Conv             StatefulPartitionedCall/vad_model/separable_conv2d/separable_conv2d/depthwise
     ├─ attr group: 1
  7 Conv             StatefulPartitionedCall/vad_model/separable_conv2d/BiasAdd
     ├─ attr group: 1
  8

In [5]:
from onnx import numpy_helper
param = {t.name: numpy_helper.to_array(t) for t in m.graph.initializer}

for k,v in param.items():
    print(f"{k:55}  {v.shape}")

new_shape__175                                           (4,)
const_fold_opt__178                                      (1, 1, 3, 3)
StatefulPartitionedCall/vad_model/separable_conv2d/separable_conv2d/ReadVariableOp_1:0  (16, 1, 1, 1)
StatefulPartitionedCall/vad_model/separable_conv2d/BiasAdd/ReadVariableOp:0  (16,)
const_fold_opt__179                                      (16, 1, 1, 3)
StatefulPartitionedCall/vad_model/separable_conv1d/ExpandDims_2:0  (16, 16, 1, 1)
StatefulPartitionedCall/vad_model/separable_conv1d/BiasAdd/ReadVariableOp:0  (16,)
const_fold_opt__180                                      (16, 1, 1, 3)
StatefulPartitionedCall/vad_model/separable_conv1d_1/ExpandDims_2:0  (16, 16, 1, 1)
StatefulPartitionedCall/vad_model/separable_conv1d_1/BiasAdd/ReadVariableOp:0  (16,)
new_shape__177                                           (3,)
W0__70                                                   (1, 256, 80)
R0__71                                                   (1, 256, 64)
B0__7

In [10]:
import torch
import torch.nn as nn

class TenVAD(nn.Module):
    def __init__(self):
        super().__init__()
        # ─── Conv front-end ──────────────────────────────────────────
        self.conv_dw = nn.Conv2d(1, 1, kernel_size=(3,3),
                         padding=(0,1),  # <── only width is padded
                         bias=False)
        self.conv_pw = nn.Conv2d(1, 16, kernel_size=1, bias=True)
        self.relu    = nn.ReLU()
        self.pool    = nn.MaxPool2d((1,2))

        self.sep1_dw = nn.Conv2d(16, 16, kernel_size=(1,3),
                                padding=(0,0), groups=16, bias=False)
        self.sep1_pw = nn.Conv2d(16, 16, kernel_size=1,  bias=True)

        self.sep2_dw = nn.Conv2d(16, 16, kernel_size=(1,3),
                                padding=(0,0), groups=16, bias=False)
        self.sep2_pw = nn.Conv2d(16, 16, kernel_size=1,  bias=True)

        # ─── RNN core ────────────────────────────────────────────────
        self.lstm1 = nn.LSTM(64, 64, batch_first=True)
        self.lstm2 = nn.LSTM(64, 64, batch_first=True)

        # ─── Densities ───────────────────────────────────────────────
        self.fc1 = nn.Linear(128, 128)
        self.fc2 = nn.Linear(128, 1)
        self.sig = nn.Sigmoid()

    def forward(self, x, h1=None, c1=None, h2=None, c2=None):
        # x: (B, 3, 41)  → (B, 1, 3, 41)
        B = x.size(0)
        x = x.unsqueeze(1)

        x = self.conv_dw(x)          # (B, 1, 1, 41)
        x = self.conv_pw(x)          # (B,16, 1, 41)
        x = self.relu(x)
        x = self.pool(x)             # (B,16, 1, 20)

        # ─ separable_conv1d (really 2-D) ─
        x = self.sep1_dw(x)          # (B,16,1,18)
        x = self.sep1_pw(x)
        x = self.relu(x)
        x = x.squeeze(2)             # (B,16,18)

        # put the singleton height back for the second block
        x = x.unsqueeze(2)           # (B,16,1,18)
        x = self.sep2_dw(x)          # (B,16,1,16)
        x = self.sep2_pw(x)
        x = self.relu(x)
        x = x.squeeze(2)             # (B,16,16)

        # transpose for the LSTM that follows later
        x = x.permute(0, 2, 1)       # (B, 16, 16)  -> think of 16 frames × 16 feat

        # LSTM stack
        h1 = torch.zeros(1,B,64, device=x.device) if h1 is None else h1
        c1 = torch.zeros(1,B,64, device=x.device) if c1 is None else c1
        x, (h1, c1) = self.lstm1(x, (h1, c1))

        h2 = torch.zeros(1,B,64, device=x.device) if h2 is None else h2
        c2 = torch.zeros(1,B,64, device=x.device) if c2 is None else c2
        x, (h2, c2) = self.lstm2(x, (h2, c2))

        # concat last outputs of both directions (here unidirectional so just x)
        x = torch.cat([x], dim=2)   # (B, T, 64) → keep for dense

        # dense layers
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return self.sig(x), (h1,c1,h2,c2)

In [11]:
dummy = torch.randn(5,3,41)
out, _ = TenVAD()(dummy)
print(out.shape)   # should be (5, 20, 1)  → matches ONNX output_1

RuntimeError: input.size(-1) must be equal to input_size. Expected 64, got 16