© 2025, Stefan Webb. Some Rights Reserved.

Except where otherwise noted, this work is licensed under a
[Creative Commons Attribution-ShareAlike 4.0 International (CC BY-SA 4.0)](https://creativecommons.org/licenses/by-sa/4.0/deed.en)

# Validating Silero VAD Implementation
We can validate our implementation by loading the pre-trained weights into our PyTorch `nn.Module` and comparing its output to that of the saved `jit.RecursiveScriptModule`.

The purpose of this notebook is to help me develop the tests for the modules making up our implementation.

In [1]:
import torch
import open_vad
from open_vad import SileroVAD
from open_vad.utils import count_parameters

## Create a model from our implementation and load the pretrained JIT version

In [18]:
# Build our re-implementation, switch it to inference mode, and report
# how many parameters it holds.
model = SileroVAD().eval()
print(count_parameters(model))

309633


In [3]:
!wget -P ../models/ https://github.com/snakers4/silero-vad/raw/refs/heads/master/src/silero_vad/data/silero_vad.jit

--2025-05-27 23:57:19--  https://github.com/snakers4/silero-vad/raw/refs/heads/master/src/silero_vad/data/silero_vad.jit
Resolving github.com (github.com)... 140.82.116.3
Connecting to github.com (github.com)|140.82.116.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/snakers4/silero-vad/refs/heads/master/src/silero_vad/data/silero_vad.jit [following]
--2025-05-27 23:57:20--  https://raw.githubusercontent.com/snakers4/silero-vad/refs/heads/master/src/silero_vad/data/silero_vad.jit
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 2606:50c0:8000::154, 2606:50c0:8002::154, 2606:50c0:8003::154, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|2606:50c0:8000::154|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2269612 (2.2M) [application/octet-stream]
Saving to: ‘../models/silero_vad.jit.2’


2025-05-27 23:57:20 (20.4 MB/s) - ‘../models/silero_vad.jit.2’ saved

In [3]:
# Load the reference TorchScript checkpoint, switch it to inference mode,
# and report how many parameters it holds.
jit_model = torch.jit.load("../models/silero_vad.jit").eval()
print(count_parameters(jit_model))

545282


As we will see, the difference in parameter count arises because the JIT module bundles two models: one for 16 kHz and one for 8 kHz sampling rates.

## Compare the parameter/buffer names and sizes

In [4]:
# One line per parameter tensor of our model: qualified name, element count.
for param_name, param in model.named_parameters():
    print(param_name, param.numel())

encoder.0.reparam_conv.weight 49536
encoder.0.reparam_conv.bias 128
encoder.1.reparam_conv.weight 24576
encoder.1.reparam_conv.bias 64
encoder.2.reparam_conv.weight 12288
encoder.2.reparam_conv.bias 64
encoder.3.reparam_conv.weight 24576
encoder.3.reparam_conv.bias 128
decoder.rnn.weight_ih 65536
decoder.rnn.weight_hh 65536
decoder.rnn.bias_ih 512
decoder.rnn.bias_hh 512
decoder.decoder.2.weight 128
decoder.decoder.2.bias 1


In [5]:
# Same listing for the JIT module, skipping every parameter that belongs to
# the "_model_8k" sub-model so the rows line up with our model's listing.
for param_name, param in jit_model.named_parameters():
    if param_name.startswith("_model_8k"):
        continue
    print(param_name, param.numel())

_model.encoder.0.reparam_conv.weight 49536
_model.encoder.0.reparam_conv.bias 128
_model.encoder.1.reparam_conv.weight 24576
_model.encoder.1.reparam_conv.bias 64
_model.encoder.2.reparam_conv.weight 12288
_model.encoder.2.reparam_conv.bias 64
_model.encoder.3.reparam_conv.weight 24576
_model.encoder.3.reparam_conv.bias 128
_model.decoder.rnn.weight_ih 65536
_model.decoder.rnn.weight_hh 65536
_model.decoder.rnn.bias_ih 512
_model.decoder.rnn.bias_hh 512
_model.decoder.decoder.2.weight 128
_model.decoder.decoder.2.bias 1


In [6]:
# Print every (name, tensor) buffer entry registered on the JIT module.
for named_buffer in jit_model.named_buffers():
    print(named_buffer)

('_model.stft.forward_basis_buffer', tensor([[[ 0.0000e+00,  1.5059e-04,  6.0227e-04,  ...,  1.3548e-03,
           6.0227e-04,  1.5059e-04]],

        [[ 0.0000e+00,  1.5055e-04,  6.0155e-04,  ...,  1.3511e-03,
           6.0155e-04,  1.5055e-04]],

        [[ 0.0000e+00,  1.5041e-04,  5.9937e-04,  ...,  1.3401e-03,
           5.9937e-04,  1.5041e-04]],

        ...,

        [[ 0.0000e+00, -7.3891e-06,  5.9033e-05,  ...,  1.9879e-04,
          -5.9033e-05,  7.3891e-06]],

        [[ 0.0000e+00, -3.6957e-06,  2.9552e-05,  ...,  9.9663e-05,
          -2.9552e-05,  3.6957e-06]],

        [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]]]))
('_model_8k.stft.forward_basis_buffer', tensor([[[ 0.0000e+00,  6.0227e-04,  2.4076e-03,  ...,  5.4117e-03,
           2.4076e-03,  6.0227e-04]],

        [[ 0.0000e+00,  6.0155e-04,  2.3960e-03,  ...,  5.3532e-03,
           2.3960e-03,  6.0155e-04]],

        [[ 0.0000e+00,  5.9937e-04,  2.3614e-03,  ..

In [7]:
# Summarise the first buffer of the JIT module: name, element count, shape.
# The (name, tensor) pair is unpacked once, so the printed label is the
# buffer's real name rather than a hand-typed string, and named_buffers()
# is only enumerated a single time.
jit_buffer_name, jit_buffer = next(iter(jit_model.named_buffers()))
print(jit_buffer_name, jit_buffer.numel(), jit_buffer.shape)

_model.stft.forward_basis_buffer 66048 torch.Size([258, 1, 256])


In [8]:
# Summarise the first buffer of our model: name, element count, shape.
# The (name, tensor) pair is unpacked once, so the printed label is the
# buffer's real name rather than a hand-typed string, and named_buffers()
# is only enumerated a single time.
model_buffer_name, model_buffer = next(iter(model.named_buffers()))
print(model_buffer_name, model_buffer.numel(), model_buffer.shape)

stft.forward_basis_buffer 66048 torch.Size([258, 1, 256])


In [10]:
# Materialise the (name, tensor) buffer pairs of our model so the notebook
# displays them as the cell result.
[*model.named_buffers()]

[('stft.forward_basis_buffer',
  tensor([[[0., 0., 0.,  ..., 0., 0., 0.]],
  
          [[0., 0., 0.,  ..., 0., 0., 0.]],
  
          [[0., 0., 0.,  ..., 0., 0., 0.]],
  
          ...,
  
          [[0., 0., 0.,  ..., 0., 0., 0.]],
  
          [[0., 0., 0.,  ..., 0., 0., 0.]],
  
          [[0., 0., 0.,  ..., 0., 0., 0.]]]))]

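It looks like the parameter and buffer names and shapes match. (The buffer values do not match yet: ours is still all zeros, and is only filled in when the pre-trained weights are loaded below.)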

## Load pre-trained weights into our implementation

In [19]:
# Build a state dict our model can consume from the JIT module's one:
#   * entries whose key starts with "_model_8k" are left out;
#   * a leading "_model." is stripped from the remaining keys so they line
#     up with the names our model uses.
state_dict = {
    key.removeprefix("_model."): tensor
    for key, tensor in jit_model.state_dict().items()
    if not key.startswith("_model_8k")
}

In [20]:
# Copy the remapped pre-trained tensors into our model. strict=True (the
# default) makes any missing or unexpected key an error, so a successful
# load confirms that the two sets of names agree.
model.load_state_dict(state_dict, strict=True)

<All keys matched successfully>

In [21]:
# Display our model's STFT basis buffer so it can be compared by eye with
# the JIT module's copy.
model.get_buffer("stft.forward_basis_buffer")

tensor([[[ 0.0000e+00,  1.5059e-04,  6.0227e-04,  ...,  1.3548e-03,
           6.0227e-04,  1.5059e-04]],

        [[ 0.0000e+00,  1.5055e-04,  6.0155e-04,  ...,  1.3511e-03,
           6.0155e-04,  1.5055e-04]],

        [[ 0.0000e+00,  1.5041e-04,  5.9937e-04,  ...,  1.3401e-03,
           5.9937e-04,  1.5041e-04]],

        ...,

        [[ 0.0000e+00, -7.3891e-06,  5.9033e-05,  ...,  1.9879e-04,
          -5.9033e-05,  7.3891e-06]],

        [[ 0.0000e+00, -3.6957e-06,  2.9552e-05,  ...,  9.9663e-05,
          -2.9552e-05,  3.6957e-06]],

        [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]]])

In [22]:
# Display the tensor of the JIT module's first registered buffer.
next(iter(jit_model.named_buffers()))[1]

tensor([[[ 0.0000e+00,  1.5059e-04,  6.0227e-04,  ...,  1.3548e-03,
           6.0227e-04,  1.5059e-04]],

        [[ 0.0000e+00,  1.5055e-04,  6.0155e-04,  ...,  1.3511e-03,
           6.0155e-04,  1.5055e-04]],

        [[ 0.0000e+00,  1.5041e-04,  5.9937e-04,  ...,  1.3401e-03,
           5.9937e-04,  1.5041e-04]],

        ...,

        [[ 0.0000e+00, -7.3891e-06,  5.9033e-05,  ...,  1.9879e-04,
          -5.9033e-05,  7.3891e-06]],

        [[ 0.0000e+00, -3.6957e-06,  2.9552e-05,  ...,  9.9663e-05,
          -2.9552e-05,  3.6957e-06]],

        [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]]])

## Compare outputs

In [26]:
# Call reset_states() on both models so they start the comparison from the
# same point (presumably this clears state carried between chunks; confirm
# against the SileroVAD implementation).
model.reset_states()
jit_model.reset_states()

# Fix the RNG so the random inputs, and therefore this check, are
# reproducible under Restart Kernel -> Run All.
torch.manual_seed(0)

# Inference only: no autograd bookkeeping is needed for the comparison.
with torch.no_grad():
    for _ in range(10):
        # One random input of shape (1, 512), fed to both models together
        # with the value 16000 as the second argument.
        input_tensor = torch.randn(1, 512)
        output = model(input_tensor, 16000)
        jit_output = jit_model(input_tensor, 16000)

        outputs_match = torch.allclose(output, jit_output)
        print(outputs_match)
        # Fail loudly instead of relying on someone reading the printed flags.
        assert outputs_match, "our implementation diverges from the JIT reference"

True
True
True
True
True
True
True
True
True
True
