In [None]:
import random
import numpy as np
import os
import torch
import torch.nn.functional as f
import torchaudio
import math
import skimage.io

In [None]:
# function to resample the audio file to given target sampling rate
def _resample_if_necessary(signal, sr, target_sr):
    if sr != target_sr:
        resampler = torchaudio.transforms.Resample(sr, target_sr)
        signal = resampler(signal)
    return signal

In [None]:
#function to cut longer audio files 
def _cut_if_necessary(signal, num_samples):
    if signal.shape[1] > num_samples:
        signal = signal[:, :num_samples]
    return signal

In [None]:
#function to apply zero padding to the audio file
def _pad_if_necessary(signal, num_samples):
    length_signal = signal.shape[1]
    if length_signal < num_samples:
        num_missing_samples = num_samples - length_signal
        if random.random() < 0.5:
            last_dim_padding = (0, num_missing_samples)
            signal = f.pad(signal, last_dim_padding)
        else:
            last_dim_padding = (math.ceil(num_missing_samples / 2), math.floor(num_missing_samples / 2))
            signal = f.pad(signal, last_dim_padding)
    return signal

In [None]:
def scale_minmax(x, min=0.0, max=1.0):
    """Linearly rescale ``x`` so that its values span ``[min, max]``.

    The smallest element of ``x`` maps to ``min`` and the largest to ``max``.
    With ``(0, 255)`` the result can be cast to ``np.uint8`` and saved as an
    8-bit image.

    Args:
        x: numeric array supporting ``.min()``, ``.max()`` and element-wise
            arithmetic (e.g. a NumPy array).
        min: value the minimum of ``x`` is mapped to. The name shadows the
            builtin but is kept for compatibility with existing callers.
        max: value the maximum of ``x`` is mapped to.

    Returns:
        Array with the same shape as ``x``. If ``x`` is constant (e.g. the log
        spectrogram of a silent clip), every element maps to ``min`` instead of
        the NaNs a 0/0 division would produce.
    """
    x_min = x.min()
    x_range = x.max() - x_min
    if x_range == 0:
        # Constant input: (x - x_min) is all zeros already; avoid 0/0 -> NaN,
        # which becomes garbage once cast to uint8.
        x_std = (x - x_min) * 0.0
    else:
        x_std = (x - x_min) / x_range
    return x_std * (max - min) + min

In [None]:
if __name__ == "__main__":
    # --- Configuration ---------------------------------------------------
    wav_files_path = 'Data/wav_files (3884 files)/'  # directory holding the input .wav files
    image_destination = 'Data/dataset/spec_images/'  # directory receiving the spectrogram PNGs
    target_sr = 48000  # target sampling rate in Hz
    num_samples = 240000  # fixed clip length in samples (5 s at 48 kHz)
    hop_length = 2048  # STFT hop length in samples
    f_min = 16  # lower frequency bound of the mel filter bank (Hz)
    f_max = 2048  # upper frequency bound of the mel filter bank (Hz)
    n_mels = 128  # number of mel filter banks
    seed = 42  # seeds the random padding choice so the dataset is reproducible

    random.seed(seed)
    os.makedirs(image_destination, exist_ok=True)

    # The transform depends only on the constants above, so build it once
    # rather than once per file.
    mel_spectrogram = torchaudio.transforms.MelSpectrogram(
        sample_rate=target_sr,
        n_fft=hop_length * 2,
        hop_length=hop_length,
        f_min=f_min,
        f_max=f_max,
        n_mels=n_mels)

    # Sorted so the sequence of random padding choices is stable across runs
    # and operating systems (os.listdir order is arbitrary).
    for file in sorted(os.listdir(wav_files_path)):
        stem, ext = os.path.splitext(file)
        if ext.lower() != '.wav':
            # Skip stray files (e.g. .DS_Store) that are not audio.
            continue

        signal, sr = torchaudio.load(os.path.join(wav_files_path, file))

        # Downmix multi-channel audio to mono. Otherwise squeeze() below would
        # leave a channel axis, and np.flip(axis=0) would reverse the channels
        # instead of the frequency axis.
        if signal.shape[0] > 1:
            signal = signal.mean(dim=0, keepdim=True)

        signal = _resample_if_necessary(signal, sr, target_sr)
        signal = _cut_if_necessary(signal, num_samples)
        signal = _pad_if_necessary(signal, num_samples)

        mels = mel_spectrogram(signal)  # (1, n_mels, time)
        mels = mels.squeeze().numpy()  # (n_mels, time)
        mels = np.log(mels + 1e-9)  # log compression; epsilon avoids log(0)
        img = scale_minmax(mels, 0, 255).astype(np.uint8)
        img = np.flip(img, axis=0)  # put low frequencies at the bottom of the image
        img = 255 - img  # invert: black == more energy

        # save as PNG
        skimage.io.imsave(os.path.join(image_destination, stem + '.png'), img)
