In [2]:
import os
import torch
import torchaudio
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [9]:
base_dir = './fluent_speech/wavs/speakers/'
wav_paths = os.listdir(base_dir)
channels, rates, samples = [], [], []
for wav_path in wav_paths:
    for wav in os.listdir(os.path.join(base_dir, wav_path)):
        sample, rate = torchaudio.load(os.path.join(base_dir, wav_path, wav))
        rates.append(int(rate))
        channels.append(int(sample.shape[0]))
        samples.append(int(sample.shape[1]))
df = pd.DataFrame({
    'rates': rates,
    'channels': channels,
    'samples': samples
    })
df

Unnamed: 0,rates,channels,samples
0,16000,1,28235
1,16000,1,28235
2,16000,1,28235
3,16000,1,31208
4,16000,1,20805
...,...,...,...
30038,16000,1,31403
30039,16000,1,39595
30040,16000,1,39595
30041,16000,1,38229


In [10]:
df.rates.value_counts(), df.channels.value_counts()

(16000    30043
 Name: rates, dtype: int64,
 1    30043
 Name: channels, dtype: int64)

In [11]:
df.samples.mean()/df.rates.mean(), df.samples[df.rates == 48000].mean() / df.rates[df.rates == 48000].mean()

(2.305681005142629, nan)

In [None]:
def rechannel(aud, new_ch):
    sig, sr = aud
    if (sig.shape[0] == new_ch):
        return sig, sr
    if (new_ch == 1):
        resig = sig[:1, :]
    else:
        resig = torch.cat([sig, sig])
    return (resig, sr)

In [5]:
def resample(aud, new_sr):
    sig, sr = aud
    if (sr == new_sr):
        return sig, sr
    resampled_ch1 = torchaudio.transforms.Resample(rate, new_sr)(sig[:1,:])
    resampled_ch2 = torchaudio.transforms.Resample(rate, new_sr)(sig[1:,:])
    resampled_sig = torch.cat([resampled_ch1, resampled_ch2])
    return (resampled_sig, new_sr)

In [6]:
def resize(aud, max_ms):
    sig, sr = aud
    num_ch, sig_len = sig.shape
    max_len = sr // 1000 * max_ms
    if sig_len > max_len:
        sig = sig[:,:max_len]
    elif sig_len < max_len:
        pad_begin_len = np.random.randint(0, max_len - sig_len)
        pad_end_len = max_len - sig_len - pad_begin_len
    pad_begin = torch.zeros((num_ch, pad_begin_len))
    pad_end = torch.zeros((num_ch, pad_end_len))
    sig = torch.cat((pad_begin, sig, pad_end), 1)
    return (sig, sr)

In [7]:
def time_shift(aud, shift_limit):
    sig, sr = aud
    sig_len = sig.shape[1]
    shift_amt = int(np.random.random() * shift_limit * sig_len)
    return (sig.roll(shift_amt), sr)

In [8]:
def mel_spectrogram(aud, n_mels=64, n_fft=1024, hop_len=None):
    sig, sr = aud
    top_db = 80
    # shape [channel, n_mels, time]
    spec = torchaudio.transforms.MelSpectrogram(sr, n_fft=n_fft, hop_length=hop_len, n_mels=n_mels)(sig)
    spec = torchaudio.transforms.AmplitudeToDB(top_db=top_db)(spec)
    return spec

In [None]:
def spectro_augment(spec, max_mask_pct=0.1, n_freq_masks=1, n_time_masks=1):
    _, n_mels, n_steps = spec.shape
    mask_value = spec.mean()
    aug_spec = spec
    freq_mask_param = max_mask_pct * n_mels
    for _ in range(n_freq_masks):
        aug_spec = torch.transforms.FrequencyMasking(freq_mask_param)(aug_spec, mask_value)
    time_mask_param = max_mask_pct * n_steps
    for _ in range(n_time_masks):
        aug_spec = torch.transforms.TimeMasking(time_mask_param)(aug_spec, mask_value)
    return aug_spec

In [None]:
class AudioDataSet(torch.utils.data.Dataset):

    def __init__(self, df_path, data_path):
        self.df = pd.read_csv(df_path)
        self.data_path = data_path
        self.duration = 2300
        self.sr = 16000
        self.channel = 1
        self.shift_pct = 0.4

    def __len__(self):
        return len(self.df)    

    def __getitem__(self, idx):
        audio_file = self.data_path + self.df.loc[idx, 'relative_path']
        class_id = self.df.loc[idx, 'classID']
        raw_aud = torchaudio.load(audio_file)
        resr_aud = resample(raw_aud, self.sr)
        rech_aud = rechannel(resr_aud, self.channel)
        resz_aud = resize(rech_aud, self.duration)
        shft_aud = time_shift(resz_aud, self.shift_pct)
        raw_spec = mel_spectrogram(shft_aud, n_mels=64, n_fft=1024, hop_len=None)
        aug_spec = spectro_augment(raw_spec, max_mask_pct=0.1, n_freq_masks=2, n_time_masks=2)
        return aug_spec, class_id