In [27]:
import os
import torch
import torchaudio
from torchaudio import transforms
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import math,random
from IPython.display import Audio

In [28]:
#load data
def load_audio_files(path: str, label: str):
    """Collect the ``.wav`` files directly inside ``path`` and tag each with ``label``.

    No audio is decoded here; only the file paths are gathered so the
    waveforms can be loaded lazily later on.

    Args:
        path: Directory that is scanned (non-recursively) for ``*.wav`` files.
        label: Label attached to every file found in ``path``.

    Returns:
        A list of ``[file_path, label]`` pairs ordered by file path. The list
        is empty when the pattern matches nothing.
    """
    # Sort the string form of each path so the ordering is deterministic.
    wav_paths = sorted(str(p) for p in Path(path).glob('*.wav'))
    return [[file_path, label] for file_path in wav_paths]

# Gather the labelled file lists for each class, then merge them
# ("good" entries first) into a single training set.
trainset_music_good = load_audio_files(path='./data/good', label='good')
trainset_music_bad = load_audio_files(path='./data/bad', label='bad')
trainset_music = [*trainset_music_good, *trainset_music_bad]


In [29]:

def open_file(audio_file):
    """Load an audio file through ``torchaudio.load``.

    Args:
        audio_file: Path of the file handed to ``torchaudio.load``.

    Returns:
        A ``(waveform, sample_rate)`` tuple as produced by ``torchaudio.load``.
    """
    signal, rate = torchaudio.load(audio_file)
    return signal, rate

#convert stereo to mono to save resources
def toMono(audio):
    """Reduce a signal to one channel by keeping only its first channel.

    Args:
        audio: ``(waveform, sample_rate)`` pair; the waveform is indexed as
            ``[channel, time]``.

    Returns:
        ``(waveform, sample_rate)`` where the waveform holds just channel 0.
        The channel dimension is retained (size 1); other channels are
        discarded, not averaged.
    """
    channels, rate = audio
    first_channel = channels[0:1, :]
    return first_channel, rate

def pad_trunc(aud, max_ms):
    """Force a signal to a fixed duration by truncating or zero-padding it.

    Args:
        aud: ``(signal, sample_rate)`` pair; ``signal`` is a 2-D tensor
            indexed as ``[channel, time]``.
        max_ms: Target duration in milliseconds.

    Returns:
        ``(signal, sample_rate)`` where ``signal`` has exactly
        ``sample_rate * max_ms // 1000`` samples per channel. Longer signals
        are cut at the end; shorter ones are padded with zeros, with the
        padding split at random between the beginning and the end.
    """
    sig, sr = aud
    num_rows, sig_len = sig.shape
    # Multiply before dividing: ``sr // 1000 * max_ms`` would drop the sub-kHz
    # part of the sample rate (44100 Hz would be treated as 44000 Hz) and
    # produce clips shorter than ``max_ms``.
    max_len = int(sr * max_ms // 1000)

    if sig_len > max_len:
        # Truncate the signal to the given length
        sig = sig[:, :max_len]

    elif sig_len < max_len:
        # Length of padding to add at the beginning and end of the signal
        pad_begin_len = random.randint(0, max_len - sig_len)
        pad_end_len = max_len - sig_len - pad_begin_len

        # Pad with 0s; new_zeros keeps the signal's dtype and device so the
        # concatenation never mixes tensor types.
        pad_begin = sig.new_zeros((num_rows, pad_begin_len))
        pad_end = sig.new_zeros((num_rows, pad_end_len))

        sig = torch.cat((pad_begin, sig, pad_end), 1)

    return (sig, sr)

def time_shift(aud, shift_limit):
    """Circularly shift a signal along its time axis by a random amount.

    Args:
        aud: ``(signal, sample_rate)`` pair; time is the last dimension of
            ``signal``.
        shift_limit: Upper bound of the shift, as a fraction of the signal
            length. The shift is ``int(random.random() * shift_limit * length)``
            samples.

    Returns:
        ``(signal, sample_rate)`` with the samples rolled along the time
        axis; samples pushed past the end wrap around to the beginning.
    """
    sig, sr = aud
    sig_len = sig.shape[-1]
    shift_amt = int(random.random() * shift_limit * sig_len)
    # Roll along the time axis only. Without ``dims`` the tensor is flattened
    # before rolling, which moves samples from one channel into the next on
    # multi-channel input.
    return (sig.roll(shift_amt, dims=-1), sr)

def pitch_shift(aud, shift_limit):
    """Shift the pitch of a signal by a random whole number of steps.

    Args:
        aud: ``(signal, sample_rate)`` pair.
        shift_limit: Upper bound for the shift. The number of steps is
            ``int(random.random() * shift_limit)``, i.e. truncated to a
            whole number.

    Returns:
        ``(signal, sample_rate)`` with the signal passed through
        ``transforms.PitchShift``; the sample rate is returned unchanged.
    """
    waveform, rate = aud
    n_steps = int(random.random() * shift_limit)
    shifter = transforms.PitchShift(sample_rate=rate, n_steps=n_steps)
    return shifter(waveform), rate
def speed_shift(aud, shift_limit):
    """Change the playback speed of a signal by a random factor.

    Args:
        aud: ``(signal, sample_rate)`` pair.
        shift_limit: Maximum relative speed change. The speed factor is
            ``1 + random.random() * shift_limit``, so ``0.2`` speeds the
            signal up by at most 20 %. Negative values greater than ``-1``
            slow it down instead.

    Returns:
        ``(signal, sample_rate)`` with the speed-adjusted signal; the sample
        rate is returned unchanged.

    Raises:
        ValueError: If ``shift_limit`` is ``-1`` or lower, which could produce
            a speed factor that is not positive.
    """
    sig, sr = aud
    if shift_limit <= -1:
        raise ValueError(f"shift_limit must be greater than -1, got {shift_limit!r}")

    # The factor has to be a positive float; truncating it to an int (as for
    # pitch steps) would frequently yield 0.
    factor = 1.0 + random.random() * shift_limit

    # Speed is a module: configure it with (orig_freq, factor), then call it
    # on the waveform. The call returns (waveform, lengths); lengths is not
    # needed here.
    sig, _ = transforms.Speed(orig_freq=sr, factor=factor)(sig)
    return (sig, sr)

def data_augment(aud):
    """Apply the augmentation chain: random pitch shift, then random time shift.

    Args:
        aud: ``(signal, sample_rate)`` pair.

    Returns:
        The augmented ``(signal, sample_rate)`` pair, produced by
        ``pitch_shift`` with a limit of 4 followed by ``time_shift`` with a
        limit of 0.1.
    """
    pitched = pitch_shift(aud, 4)
    return time_shift(pitched, 0.1)
def stretch(spec):
    """Time-stretch a spectrogram by a random rate in ``[0.9, 1.1)``.

    Args:
        spec: Spectrogram whose second-to-last dimension is frequency.
            NOTE(review): ``transforms.TimeStretch`` is a phase vocoder and
            presumably needs a complex-valued spectrogram, not the dB mel
            output of ``spectro_gram`` - confirm before wiring this in.

    Returns:
        The spectrogram after ``transforms.TimeStretch`` at the drawn rate.
    """
    # Draw the rate as a float. Wrapping the random term in int() truncates
    # it to 0 every time, which pins the rate at 0.9.
    rate = 0.9 + random.random() * 0.2
    # Take the number of frequency bins from the input instead of assuming 512.
    n_freq = spec.shape[-2]
    spec = transforms.TimeStretch(n_freq=n_freq, fixed_rate=rate)(spec)
    return spec

def spectro_gram(aud, n_mels=512, n_fft=4096, hop_len=None):
    """Turn a signal into a mel spectrogram expressed in decibels.

    Args:
        aud: ``(signal, sample_rate)`` pair.
        n_mels: Number of mel bands passed to ``transforms.MelSpectrogram``.
        n_fft: FFT size passed to ``transforms.MelSpectrogram``.
        hop_len: Hop length passed to ``transforms.MelSpectrogram`` as
            ``hop_length``; ``None`` is forwarded unchanged.

    Returns:
        The spectrogram after ``transforms.AmplitudeToDB`` with
        ``top_db=80``, laid out as ``[channel, n_mels, time]``.
    """
    waveform, rate = aud

    # Build both stages up front, then chain them: mel power -> decibels.
    to_mel = transforms.MelSpectrogram(rate, n_fft=n_fft, hop_length=hop_len, n_mels=n_mels)
    to_db = transforms.AmplitudeToDB(top_db=80)

    return to_db(to_mel(waveform))


# Render every training clip to a grayscale mel-spectrogram image.
# Pass 0 saves the plain clip; pass 1 saves a randomly augmented copy, so each
# source file yields ``<name>_0.png`` and ``<name>_1.png``.
for i in range(0, 2):
    for data in trainset_music:
        file_path, label = data[0], data[1]

        wv = open_file(file_path)
        wv = toMono(wv)
        wv = pad_trunc(wv, 60000 * 2)  # fix every clip to 120 000 ms
        if i > 0:
            wv = data_augment(wv)
        spec = spectro_gram(wv)
        # Keep the single channel and convert it to a NumPy array for matplotlib.
        spec = spec[0].detach().numpy()

        # File name without directory or extension.
        filename, _ = os.path.splitext(os.path.basename(file_path))

        # imsave does not create missing folders, so make sure the per-label
        # output directory exists before writing.
        out_dir = f'./data/spectrograms/{label}'
        os.makedirs(out_dir, exist_ok=True)
        plt.imsave(f'{out_dir}/{filename}_{i}.png', spec, cmap='gray')
    

0
1
