In [None]:
#!conda env update --file environment.yml --prune

Uses sorted batching: each LibriSpeech WAV is treated independently of the other files in its parent folder, and the WAVs are sorted by duration before they are mixed with noise.

In [1]:
import os, random, glob, ntpath, logging
import numpy as np
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

import torch, torchaudio
import torchaudio.functional as F
from IPython.display import Audio

from tqdm import tqdm

# `mode` ("train" or "test") selects the source split and names the output
# sub-folders under data/ (e.g. data/mixed/<mode>), so it must match the sources.
mode = "test"
if os.path.exists("/vol/research/FYP_Leo"):
    # Only the training split is configured for the remote machine; force the
    # mode so its output does not land in the `test` folders.
    mode = "train"
    speech_src = "/vol/research/FYP_Leo/LibriSpeech/train-clean-100"
else:
    speech_src = r"D:\fyp\train-clean-100" if mode == "train" else r"D:\fyp\test-clean"
noise_src = f"data/noise/{mode}"



def delete_files_in_data(dir: str) -> int:
    """Recursively delete every file matching ``*.*`` under ``dir``.

    Safety guard: ``dir`` must contain a path component named exactly ``data``;
    otherwise nothing is deleted and an error is logged. Sub-directories are
    left in place (only files are removed).

    Args:
        dir: Directory to clear.

    Returns:
        The number of files deleted (0 when the guard rejects ``dir``).
    """
    # Compare whole path components: a substring test would also accept
    # paths such as "metadata/..." or "/home/database".
    if "data" not in os.path.normpath(dir).split(os.sep):
        logging.error("dir must be in the directory `data` (got %r)", dir)
        return 0

    deleted = 0
    pattern = os.path.join(glob.escape(dir), "**", "*.*")
    for path in glob.glob(pattern, recursive=True):
        # "*.*" also matches directories whose names contain a dot; skip them
        # instead of letting os.remove fail on them.
        if not os.path.isfile(path):
            continue
        try:
            os.remove(path)
            deleted += 1
        except OSError as exc:
            # Log and keep going so one unremovable file does not abort the cleanup.
            logging.error("Error occurred while deleting %s: %s", path, exc)
    return deleted



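## Basic mixing

Each speech file under `data/speech/<mode>` is mixed with a randomly chosen noise file under `data/noise/<mode>` at a random SNR of -5, 0, 5 or 10 dB.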

In [None]:
# Pair every speech clip with a random noise clip at a random SNR and write the
# mixture to data/mixed/<mode>/<speech>_<noise>_<snr>.wav.
speech_list = glob.glob(f"data/speech/{mode}/**/*.wav", recursive=True)
noise_list = glob.glob(f"data/noise/{mode}/**/*.wav", recursive=True)
output_dir = f"data/mixed/{mode}"
os.makedirs(output_dir, exist_ok=True)  # torchaudio.save does not create folders

for speech_file in speech_list:
    snr = random.randint(-1, 2) * 5  # one of -5, 0, 5, 10 dB
    noise_file = random.choice(noise_list)

    speech_name = ntpath.basename(speech_file).removesuffix(".wav")
    noise_name = ntpath.basename(noise_file).removesuffix(".wav")
    # Plain f-string: nesting same-type quotes inside `{}` is a SyntaxError before Python 3.12.
    output_path = f"{output_dir}/{speech_name}_{noise_name}_{snr}.wav"

    # Load the files chosen above (not fixed paths), so each output is a distinct mixture.
    speech, _ = torchaudio.load(speech_file)
    noise, _ = torchaudio.load(noise_file)

    # Tile short noise so it covers the whole utterance, then crop to the speech length.
    if noise.shape[1] < speech.shape[1]:
        noise = noise.repeat(1, -(-speech.shape[1] // noise.shape[1]))
    noise = noise[:, : speech.shape[1]]

    # add_noise requires snr to have one fewer dim than the waveform, hence shape (1,).
    mixed = F.add_noise(speech, noise, torch.tensor([snr]))

    # NOTE(review): assumes speech and noise are both 16 kHz; nothing here resamples them.
    torchaudio.save(output_path, mixed, sample_rate=16000, format="wav", encoding="PCM_S")

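## Sorted mixing

Speech files are ordered by duration (or shuffled when `is_descending is None`). Each one is then mixed with a random window of a random noise file and saved as `<index>.wav`, so the index reflects the chosen order.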

In [2]:
# Sorted mixing: order the speech clips by duration (or shuffle them), mix each
# with a random window of a random noise clip at a random SNR, and save the
# clean/mixed pairs as <index>.wav so the index follows the chosen order.
import math

max_samples = None  # None = use all samples
max_duration = None  # limit in samples, e.g. 20 * 3600 * 16000 for 20 hours; None = no limit
is_descending = True  # True/False: sort by duration; None: random order
sample_rate = 16000  # every saved file is written (and resampled, if needed) at this rate

totalcount = 0

delete_files_in_data("data/speech_ordered/" + mode)
delete_files_in_data("data/mixed/" + mode)
# torchaudio.save does not create missing folders.
os.makedirs(f"data/speech_ordered/{mode}", exist_ok=True)
os.makedirs(f"data/mixed/{mode}", exist_ok=True)

rnd_seed = 42
# One seeded RNG drives ordering, noise choice, crop offset and SNR, so a
# re-run reproduces the same dataset.
rnd = random.Random(rnd_seed)

speech_list = glob.glob(f"{speech_src}/**/*.wav", recursive=True)
noise_list = glob.glob(f"{noise_src}/**/*.wav", recursive=True)
print(f"speech_list:{len(speech_list)}")
print(f"noise_list:{len(noise_list)}")
assert max_samples is None or max_samples <= len(speech_list)
assert noise_list, f"no noise .wav files found under {noise_src}"

# (num_frames, path) pairs used for ordering.
lst: list[tuple[int, str]] = [
    (torchaudio.info(path).num_frames, path) for path in tqdm(speech_list)
]

if is_descending is None:
    rnd.shuffle(lst)
else:
    lst.sort(key=lambda item: item[0], reverse=is_descending)


def _load_at_rate(path: str) -> torch.Tensor:
    """Load a wav file, resampling it to `sample_rate` when its rate differs."""
    waveform, sr = torchaudio.load(path, format="wav")
    if sr != sample_rate:
        waveform = F.resample(waveform, sr, sample_rate)
    return waveform


def _random_noise_window(noise: torch.Tensor, length: int, rng: random.Random) -> torch.Tensor:
    """Return a random `length`-sample window of `noise`, tiling it first when it is shorter."""
    if noise.shape[1] < length:
        # Without tiling, randint(0, <negative>) raises and add_noise sees mismatched lengths.
        noise = noise.repeat(1, math.ceil(length / noise.shape[1]))
    offset = rng.randint(0, noise.shape[1] - length)
    return noise[:, offset:offset + length]


mixed = None
for i, (_, file) in enumerate(tqdm(lst[:max_samples]), start=1):
    speech = _load_at_rate(file)
    noise = _random_noise_window(_load_at_rate(rnd.choice(noise_list)), speech.shape[1], rnd)

    # add_noise requires snr to have one fewer dim than the waveform, hence shape (1,).
    snr = torch.tensor([rnd.randint(-1, 2) * 5])  # -5, 0, 5 or 10 dB
    mixed = F.add_noise(speech, noise, snr)

    torchaudio.save(f"data/speech_ordered/{mode}/{i}.wav", speech, sample_rate=sample_rate, format="wav", encoding="PCM_S")
    torchaudio.save(f"data/mixed/{mode}/{i}.wav", mixed, sample_rate=sample_rate, format="wav", encoding="PCM_S")
    totalcount += speech.shape[-1]
    if max_duration is not None and totalcount > max_duration:
        break

print("Total samples:" + str(totalcount))
print(f"Hours:{math.floor(totalcount/(3600*sample_rate))}, Seconds:{((totalcount/sample_rate) % 3600):.2f}")
# Preview the last mixture; nothing to play if no file was processed.
Audio(mixed.numpy()[0], rate=sample_rate) if mixed is not None else None






speech_list:2620
noise_list:32


100%|██████████| 2620/2620 [00:44<00:00, 58.93it/s] 
100%|██████████| 2620/2620 [03:03<00:00, 14.26it/s]


Total samples:311239690
Hours:5, Seconds:1452.48


In [None]:
# Frame counts of the first 10 clips in processing order. Extract the numbers
# directly: np.array(lst) on mixed (int, str) rows would coerce them to strings.
np.array([num_frames for num_frames, _ in lst[:10]])