# Test the calculation code for normalization statistics

BYOL-A requires the normalization statistics, the average and standard deviation, precomputed using dataset samples.

The implementation for the SUPERB uses running statistics; we test them.

### Results

The online calculation is confirmed to be close enough to the offline calculation, which the original BYOL-A does.

In [2]:
import sys
import numpy as np
import torch
import torchaudio
from pathlib import Path

# load the list of VoxCeleb1 files
vc1files = list(Path('/lab/data/voxceleb1/dev').rglob('*.wav'))

## Calculate normalization statistics *offline*

In [3]:
from byol_a import load_yaml_config, LogMelSpectrogram, RunningNorm

config = load_yaml_config('./config.yaml')
to_logmelspec = LogMelSpectrogram()
normalizer = RunningNorm(epoch_samples=10_000, max_update_epochs=1, axis=[0, 1, 2]) # Use single scalar mean/std values.

In [4]:

lms_list = []
for f in np.random.choice(vc1files, size=10000, replace=False):
    # load .wav data in the same way superb does.
    wav, sr = torchaudio.load(f) # https://github.com/s3prl/s3prl/blob/main/s3prl/downstream/voxceleb1/dataset.py#L106
    # pad in the same way we do in byol_a.py.
    wav = torch.nn.functional.pad(wav, (0, 128000))[..., :128000]
    # convert it to a log-mel spectrogram as a batch of single file.
    lms = to_logmelspec(wav[None, ...])
    lms_list.append(lms)

lms_list = torch.cat(lms_list, dim=2)
lms_list.shape, lms_list.mean(), lms_list.std()

(torch.Size([1, 64, 8010000]), tensor(-8.9265), tensor(4.9163))

## Calculate normalization statistics *online*

In [5]:
runnorm = RunningNorm(epoch_samples=10000, max_update_epochs=1, axis=[0, 1, 2])
for f in np.random.choice(vc1files, size=10000, replace=False):
    # load .wav data in the same way superb does.
    wav, sr = torchaudio.load(f) # https://github.com/s3prl/s3prl/blob/main/s3prl/downstream/voxceleb1/dataset.py#L106
    # pad in the same way we do in byol_a.py.
    wav = torch.nn.functional.pad(wav, (0, 128000))[..., :128000]
    # convert it to a log-mel spectrogram as a batch of single file.
    lms = to_logmelspec(wav[None, ...])
    runnorm(lms)

In [6]:
runnorm.ema_mean(), runnorm.ema_var().sqrt()

(tensor([[[-8.9218]]]), tensor([[[4.9206]]]))