# Audio representations

In [None]:
import os
import math
import IPython
import torch
import torchaudio
from audiomentations import Compose, Normalize, AddGaussianSNR, AddGaussianNoise, AddImpulseResponse, AddShortNoises, AddBackgroundNoise
import matplotlib.pyplot as plt
import matplotlib
import matplotlib.patches
import librosa
import torch.nn.functional as F
import seaborn as sns
import numpy as np
#torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = False
# Pick the torchaudio I/O backend before any audio file is read.
torchaudio.set_audio_backend("sox_io")

# Notebook-wide settings.
DATA_PATH = '../data/'  # root folder for the corpora used in this notebook
n_fft = 512             # FFT size handed to the spectral transforms
plt.rcParams['figure.dpi'] = 160

# Download the LibriSpeech test-clean split only when its folder is missing.
librispeech_dir = os.path.join(DATA_PATH, 'LibriSpeech')
if not os.path.exists(librispeech_dir):
    torchaudio.datasets.LIBRISPEECH(DATA_PATH, url='test-clean', download=True)

In [None]:
def _load_utterances(speaker_id, chapter_id, first_index, count=4):
    """Load `count` utterances from one LibriSpeech test-clean folder.

    On step i (i = 1..count) the utterance with index ``first_index // i`` is
    read, e.g. 36, 18, 12, 9 for ``first_index=36``.

    Args:
        speaker_id: first path component below ``test-clean`` (as a string).
        chapter_id: second path component below ``test-clean`` (as a string).
        first_index: utterance index used on the first step.
        count: number of utterances to load.

    Returns:
        A list with the waveform returned by ``torchaudio.load`` for each file.
    """
    waveforms = []
    for divisor in range(1, count + 1):
        # A dedicated local name is used for the path on purpose: the global
        # `test_file_path` must keep pointing at the file `signal` came from.
        utterance_path = os.path.join(
            DATA_PATH,
            "LibriSpeech/test-clean/{0}/{1}/{0}-{1}-{2:04d}.flac".format(
                speaker_id, chapter_id, first_index // divisor
            ),
        )
        waveform, _ = torchaudio.load(utterance_path)
        waveforms.append(waveform)
    return waveforms


# Reference utterance used by most of the cells below.
test_file_path = os.path.join(DATA_PATH, "LibriSpeech/test-clean/121/127105/121-127105-0036.flac")
signal, sample_rate = torchaudio.load(test_file_path)
print(vars(torchaudio.info(test_file_path)))

# Four utterances from each of three different folders of the corpus.
signals1 = _load_utterances('121', '127105', 36)
signals2 = _load_utterances('260', '123286', 31)
signals3 = _load_utterances('7127', '75946', 29)

In [None]:
import torch.nn.functional as F
class PreEmphasis(torch.nn.Module):  # pylint: disable=abstract-method
    """Pre-emphasis filter ``y[t] = x[t] - coef * x[t-1]`` for 2-D inputs.

    The two filter taps are kept in the non-trainable buffer
    ``flipped_filter`` and applied with a 1-D convolution.
    """

    def __init__(self, coef: float = 0.97):
        super().__init__()
        self.coef = coef
        # F.conv1d computes a cross-correlation, so the taps are stored in
        # reversed order: -coef lines up with x[t-1] and 1 with x[t].
        kernel = torch.FloatTensor([-self.coef, 1.]).view(1, 1, 2)
        self.register_buffer('flipped_filter', kernel)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Filter every row of a 2-D tensor; the output has the input's shape."""
        assert input.dim() == 2, 'The number of dimensions of input tensor must be 2!'
        # Insert a channel axis for conv1d, then prepend one reflected sample
        # so the length-2 kernel yields as many outputs as there are inputs.
        with_channel = input.unsqueeze(1)
        padded = F.pad(with_channel, (1, 0), 'reflect')  # type: ignore
        filtered = F.conv1d(padded, self.flipped_filter)  # type: ignore
        return filtered.squeeze(1)

n_mels = 128
n_fft = 512


def _save_log_mel_figures(signals, file_prefix):
    """Write one log-mel spectrogram PDF per waveform.

    Every waveform is pre-emphasised, cropped to samples 5000..30000,
    converted to a mel spectrogram and saved on a log2 scale as
    ``'<file_prefix><i>.pdf'``, where ``i`` is the position in `signals`.

    Args:
        signals: iterable of 2-D waveform tensors.
        file_prefix: file name prefix of the generated PDFs.
    """
    # The transforms do not depend on the waveform, so build them once.
    pre_emphasis = PreEmphasis()
    to_mel = torchaudio.transforms.MelSpectrogram(
        n_mels=n_mels, n_fft=n_fft, win_length=400, hop_length=160,
        window_fn=torch.hamming_window,
    )
    for i, s in enumerate(signals):
        # The 1e-6 floor keeps log2 finite for empty mel bins.
        mel_spectro = to_mel(pre_emphasis(s)[:, 5000:30000]) + 1e-6
        # Empty the current figure first; otherwise each imshow call would be
        # drawn on top of the images left behind by earlier iterations.
        plt.clf()
        plt.axis('off')
        plt.imshow(mel_spectro.log2()[0, :, :], cmap='viridis')
        plt.gca().invert_yaxis()
        plt.savefig('{}{}.pdf'.format(file_prefix, i), bbox_inches='tight', transparent=True, pad_inches=0)


_save_log_mel_figures(signals1, 'spkr1-mel')
_save_log_mel_figures(signals2, 'spkr2-mel')
_save_log_mel_figures(signals3, 'spkr3-mel')

In [None]:
# Plot the reference waveform without axes and save it.
plt.plot(signal.t().detach())
plt.axis('off')
plt.savefig('waveform.pdf')
# Play the plotted samples themselves instead of re-reading a file path, so
# the audio widget always matches the figure above.
IPython.display.Audio(signal, rate=sample_rate)

## Spectrogram

In [None]:
# Spectrogram of the reference signal: 400-sample Hamming window, 160-sample hop.
to_spectrogram = torchaudio.transforms.Spectrogram(
    n_fft=n_fft,
    win_length=400,
    hop_length=160,
    window_fn=torch.hamming_window,
)
spectro = to_spectrogram(signal)
print(spectro.shape)

# Show the first channel on a log2 scale, flipping the y-axis after imshow.
log_spectro = spectro.log2()[0, :, :]
plt.imshow(log_spectro, cmap='viridis')
plt.xlabel('time')
plt.ylabel('frequency')
plt.gca().invert_yaxis()

## Mel-spectrogram

In [None]:
import matplotlib.ticker  # used below; do not rely on pyplot importing it implicitly

n_mels = 128
# torchaudio lays out its mel filterbank on the HTK mel scale by default,
# whereas librosa defaults to the Slaney scale; request HTK spacing so the
# tick labels refer to the same scale as the plotted bins.
# NOTE(review): the labels remain approximate, since librosa returns n_mels
# points from fmin to fmax rather than the exact filter centres.
melfreqs = librosa.core.mel_frequencies(fmin=0.0, fmax=sample_rate // 2, n_mels=n_mels, htk=True)


def _format_mel_tick(x, pos, freqs=melfreqs):
    """Return the frequency in Hz for the mel bin at tick position `x`.

    `freqs` is bound at definition time so that a later reassignment of
    `melfreqs` cannot change the labels of this figure. Positions outside the
    table get an empty label.
    """
    bin_index = int(x)  # tick positions may be floats, which cannot index an array
    if not 0 <= bin_index < len(freqs):
        return ''
    return '{0:g}'.format(int(freqs[bin_index]))


mel_ticks = matplotlib.ticker.FuncFormatter(_format_mel_tick)
mel_spectro = torchaudio.transforms.MelSpectrogram(n_mels=n_mels, n_fft=n_fft)(signal)
print(mel_spectro.shape)


fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True, sharey=True, dpi=300)

# Top panel: no window_fn given, i.e. the transform's default window.
axes[0].set_title('hann window')
axes[0].imshow(mel_spectro.log2()[0,:,:], cmap='viridis')
axes[0].set_yticks(np.arange(0, mel_spectro.size(1), 30))
axes[0].yaxis.set_major_formatter(mel_ticks)

# Bottom panel: same settings with an explicit Hamming window.
mel_spectro = torchaudio.transforms.MelSpectrogram(n_mels=n_mels, n_fft=n_fft, window_fn=torch.hamming_window)(signal)
axes[1].set_title('hamming window')
axes[1].imshow(mel_spectro.log2()[0,:,:], cmap='viridis')
axes[1].set_yticks(np.arange(0, mel_spectro.size(1), 30))
axes[1].yaxis.set_major_formatter(mel_ticks)

# The pyplot calls below act on the current axes; the y-axis is shared.
plt.ylabel('frequency (Hz)')
plt.xlabel('time')
plt.gca().invert_yaxis()
#plt.savefig('melspectro.pdf') 

In [None]:
n_mels = 64
melfreqs = librosa.core.mel_frequencies(fmin=0.0, fmax=sample_rate // 2, n_mels=n_mels)

# Both transforms share every setting except the number of mel bands.
stft_kwargs = dict(n_fft=n_fft, win_length=400, hop_length=160, window_fn=torch.hamming_window)
mel_spectro = torchaudio.transforms.MelSpectrogram(n_mels=n_mels, **stft_kwargs)(signal)
mel_spectro_40 = torchaudio.transforms.MelSpectrogram(n_mels=40, **stft_kwargs)(signal)

fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True, sharey=False, dpi=300)
# These pyplot calls label whichever axes is current right after subplots().
plt.ylabel('frequency')
plt.xlabel('time')

# One panel per resolution; the y-axes are independent (sharey=False), so
# each one is flipped on its own.
panels = (('64 mels', mel_spectro), ('40 mels', mel_spectro_40))
for ax, (title, mel) in zip(axes, panels):
    ax.set_title(title)
    ax.imshow(mel.log2()[0, :, :], cmap='viridis')
    ax.invert_yaxis()


In [None]:
n_mels = 128
# Mel spectrogram of the reference signal; the 1e-6 floor keeps log2 finite.
mel_spectro = torchaudio.transforms.MelSpectrogram(n_mels=n_mels, n_fft=n_fft, win_length=400, hop_length=160, window_fn=torch.hamming_window)(signal) + 1e-6
plt.axis('off')
plt.imshow(mel_spectro.log2()[0,:,:], cmap='viridis')
plt.gca().invert_yaxis()
# `transparent` is a boolean flag; the string 'True' only worked because any
# non-empty string is truthy.
plt.savefig('melspectrogram.pdf', bbox_inches='tight', transparent=True, pad_inches=0)

## MelSpectrogram with pre-emphasis and instance normalization

In [None]:
import torch.nn.functional as F
class PreEmphasis(torch.nn.Module):  # pylint: disable=abstract-method
    """Apply ``y[t] = x[t] - coef * x[t-1]`` to each row of a 2-D tensor."""

    def __init__(self, coef: float = 0.97):
        super().__init__()
        self.coef = coef
        # Kernel of shape (1, 1, 2). F.conv1d cross-correlates, so the taps
        # appear in reversed order: [-coef, 1].
        taps = torch.FloatTensor([[[-self.coef, 1.]]])
        self.register_buffer('flipped_filter', taps)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return the filtered tensor, with the same shape as `input`."""
        assert input.dim() == 2, 'The number of dimensions of input tensor must be 2!'
        # One reflected sample on the left keeps input and output lengths equal.
        padded = F.pad(input.unsqueeze(1), (1, 0), 'reflect')  # type: ignore
        output = F.conv1d(padded, self.flipped_filter)  # type: ignore
        return output.squeeze(1)

melfreqs = librosa.core.mel_frequencies(fmin=0.0, fmax=sample_rate // 2, n_mels=n_mels)

# A single transform / normalisation instance serves all three variants.
to_mel = torchaudio.transforms.MelSpectrogram(
    n_mels=n_mels, n_fft=n_fft, win_length=400, hop_length=160,
    window_fn=torch.hamming_window,
)
instance_norm = torch.nn.InstanceNorm1d(n_mels)

# z: log-mel of the raw signal
# y: z followed by instance normalisation
# x: pre-emphasis, then log-mel, then instance normalisation
z = (to_mel(signal) + 1e-6).log()
y = instance_norm(z)
x = instance_norm((to_mel(PreEmphasis()(signal)) + 1e-6).log())

fig, axes = plt.subplots(nrows=3, ncols=1, sharex=True, sharey=True, dpi=300)

# Each title reports the value range of the features drawn in that panel.
panels = (
    ('orig range [{:.4}, {:.4}]', z),
    ('instancenorm range [{:.4}, {:.4}]', y),
    ('preemph range [{:.4}, {:.4}]', x),
)
for ax, (title_fmt, feats) in zip(axes, panels):
    ax.set_title(title_fmt.format(torch.min(feats), torch.max(feats)))
    ax.imshow(feats[0, :, :], cmap='viridis')

plt.gca().invert_yaxis()
fig.tight_layout()

### MelSpectrogram with SpecAugment (without time warping)
https://arxiv.org/abs/1904.08779
https://discuss.pytorch.org/t/does-sparse-image-warp-from-tf-exist-in-pytorch/43514

In [None]:
# Mask-width ratios. They are deliberately not called `F` and `T`: `F` is this
# notebook's alias for torch.nn.functional, and rebinding it to a float would
# break PreEmphasis.forward on any later call.
freq_mask_ratio = 0.25
time_mask_ratio = 0.15

freq_mask_param = int(freq_mask_ratio * n_mels)
# NOTE(review): the time-mask width is derived from the number of STFT
# frequency bins (n_fft // 2 + 1), not from a frame count; kept as in the
# original, confirm that this is intended.
time_mask_param = int(time_mask_ratio * (n_fft // 2 + 1))

# Mel spectrogram followed by two frequency masks and two time masks.
mel_spectro_spec_aug = torch.nn.Sequential(
    torchaudio.transforms.MelSpectrogram(n_mels=n_mels, n_fft=n_fft, win_length=400, hop_length=160, window_fn=torch.hamming_window),
    torchaudio.transforms.FrequencyMasking(freq_mask_param=freq_mask_param),
    torchaudio.transforms.FrequencyMasking(freq_mask_param=freq_mask_param),
    torchaudio.transforms.TimeMasking(time_mask_param=time_mask_param),
    torchaudio.transforms.TimeMasking(time_mask_param=time_mask_param)
)(signal)

plt.imshow(mel_spectro_spec_aug.log2()[0,:,:], cmap='viridis')
plt.ylabel('frequency')
plt.xlabel('time')
plt.gca().invert_yaxis()
#plt.savefig('melspectro-aug.pdf') 

## Mel-frequency cepstral coefficients

In [None]:
# MFCCs of the reference signal, with as many coefficients as mel bands.
to_mfcc = torchaudio.transforms.MFCC(sample_rate=sample_rate, n_mfcc=n_mels, log_mels=True)
mfcc = to_mfcc(signal)
print(mfcc.shape)

# Heatmap of the first channel; every 50th x tick and every 5th y tick is labelled.
plt.figure(dpi=190)
ax = sns.heatmap(
    mfcc[0, :, :],
    xticklabels=50,
    yticklabels=5,
    cbar_kws={'label': 'amplitude'},
)
ax.set(xlabel='time', ylabel='coefficients')
plt.gca().invert_yaxis()
plt.show()
#ax.get_figure().savefig('mfcc.pdf') 

# Comparisons to signal augmented with noise and reverberation

In [None]:
# Augmentation chain: short noise bursts, background noise, then an impulse
# response. Every stage is applied with probability 1; the Gaussian-SNR and
# Normalize stages are left disabled.
rirs_noises_dir = os.path.join(DATA_PATH, 'RIRS_NOISES')
noise_dir = os.path.join(rirs_noises_dir, 'pointsource_noises')
rir_dir = os.path.join(rirs_noises_dir, 'simulated_rirs')
augment = Compose([
    #AddGaussianSNR(min_SNR=0.5, max_SNR=0.5, p=1.0),
    AddShortNoises(noise_dir, max_snr_in_db=80, p=1.0),
    AddBackgroundNoise(noise_dir, p=1.0),
    AddImpulseResponse(rir_dir, leave_length_unchanged=True, p=1.0),
    #Normalize()
])

# The augmenter is fed a flat NumPy array of samples.
flat_samples = signal.t().numpy().flatten()
augmented_np = augment(samples=flat_samples, sample_rate=sample_rate)

print(np.min(augmented_np), np.max(augmented_np))

# Original and augmented waveform side by side on shared axes.
fig, axes = plt.subplots(nrows=1, ncols=2, sharex=True, sharey=True)
waveforms = (('original', signal.t().detach()), ('augmented', augmented_np))
for ax, (title, samples) in zip(axes, waveforms):
    ax.set_title(title)
    ax.plot(samples)

fig.tight_layout()

In [None]:
# Audio player for the augmented samples.
IPython.display.Audio(data=augmented_np, rate=sample_rate)

In [None]:
# Audio player for the unmodified reference signal.
IPython.display.Audio(data=signal, rate=sample_rate)

## Spectrogram

In [None]:
# Turn the augmented samples back into a (1, num_samples) tensor.
augmented = torch.from_numpy(augmented_np).view(1,-1)

fig, axes = plt.subplots(nrows=1, ncols=2, sharex=True, sharey=True)

axes[0].set(xlabel='time', ylabel='frequency')
axes[0].set_title('original')
axes[0].imshow(spectro.log2()[0,:,:], cmap='viridis')
axes[0].invert_yaxis()

# Use the window and hop that `spectro` was computed with (400-sample Hamming
# window, 160-sample hop). With the transform defaults the two panels would
# have different time resolution although they share the x-axis.
spectro_aug = torchaudio.transforms.Spectrogram(n_fft=n_fft, win_length=400, hop_length=160,
                                                window_fn=torch.hamming_window)(augmented)
axes[1].set_title('augmented')
axes[1].imshow(spectro_aug.log2()[0,:,:], cmap='viridis')
axes[1].invert_yaxis()
fig.tight_layout()

## Mel-spectrogram

In [None]:
# Compute both panels here with one transform and the same 1e-6 floor. The
# comparison then does not depend on which earlier cell last assigned
# `mel_spectro`, and log2 never sees an exact zero.
to_mel = torchaudio.transforms.MelSpectrogram(n_fft=n_fft, n_mels=n_mels, win_length=400, hop_length=160, window_fn=torch.hamming_window)
mel_spectro = to_mel(signal) + 1e-6
mel_spectro_aug = to_mel(augmented) + 1e-6

fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True, sharey=True)

axes[0].set_title('original')
axes[0].imshow(mel_spectro.log2()[0,:,:], cmap='viridis')
axes[0].invert_yaxis()

axes[1].set(xlabel='time', ylabel='frequency')
axes[1].set_title('augmented')
axes[1].imshow(mel_spectro_aug.log2()[0,:,:], cmap='viridis')
axes[1].invert_yaxis()
fig.tight_layout()

## Mel-frequency cepstral coefficients

In [None]:
# Request as many coefficients as `mfcc` has, so both heatmaps have the same
# number of rows on their shared y-axis.
mfcc_aug = torchaudio.transforms.MFCC(sample_rate=sample_rate, n_mfcc=mfcc.size(1), log_mels=True)(augmented)

# Set the dpi on the figure that is actually drawn; a separate plt.figure()
# call would only leave an empty extra figure behind.
fig, axes = plt.subplots(nrows=1, ncols=2, sharex=True, sharey=True, dpi=190)

sns.heatmap(mfcc[0,:,:], ax=axes[0], xticklabels=50, yticklabels=5)
axes[0].set(xlabel='time', ylabel='coefficients')
axes[0].set_title('original')
axes[0].invert_yaxis()

axes[1].set_title('augmented')
sns.heatmap(mfcc_aug[0,:,:], ax=axes[1], xticklabels=50, yticklabels=5, cbar_kws={'label': 'amplitude'})
axes[1].invert_yaxis()
fig.tight_layout()
plt.show()