In [1]:
from datasets import load_dataset
from pydub import AudioSegment,silence,playback
from tqdm.notebook import tqdm
import glob
import os
import math
import subprocess
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
from collections import namedtuple
from pydub.utils import mediainfo

In [None]:
num_worker = 30

#### Pass 1: Audio Segmentation on Silence

In [36]:
path = '/mnt/disk1/vinhdq/pretrain-wav2vec2/data/segmented/'
os.makedirs(path,exist_ok=True)
min_len = 2
max_len = 20

In [3]:
def segment_audio(filenames,min_len=min_len):
    target_length = min_len * 1000
    for f in tqdm(filenames):
        name = os.path.basename(f)
        sound = AudioSegment.from_file(f,format='wav')
        splits = silence.split_on_silence(sound,min_silence_len=500,silence_thresh=-33)
        output_chunks = [splits[0]]
        
        for chunk in splits[1:]:
            if len(output_chunks[-1]) < target_length:
                output_chunks[-1] += chunk
            else:
                output_chunks.append(chunk)
                
        if len(output_chunks[-1]) < target_length and len(output_chunks) > 1:
            chunk = output_chunks.pop()
            output_chunks[-1] += chunk
        
        for i,chunk in enumerate(output_chunks):
            chunk.export(os.path.join(path,name[:-4] + f'_{i}.wav'),format='wav')

In [4]:
filenames = glob.glob('/mnt/disk1/vinhdq/pretrain-wav2vec2/data/part_4/*.wav')
n_sample_per_worker = math.ceil(len(filenames)/num_worker)

In [5]:
with ProcessPoolExecutor(num_worker) as executor:
    for i in range(num_worker):
        executor.submit(segment_audio,filenames[i*n_sample_per_worker:(i+1)*n_sample_per_worker])

#### Check for too short/long files

In [67]:
CheckResult = namedtuple('CheckResult',['less','more'])

In [68]:
files = glob.glob('data/segmented/*.wav')
n_sample_per_worker = math.ceil(len(files)/num_worker)
indices = list(range(len(files)))

In [70]:
def checking(filename,idc,lower_bound=min_len,upper_bound=max_len):
    less = []
    more = []
    for idx in idc:
        fn = filename[idx]
        dur = float(mediainfo(fn)['duration'])
        
        if dur > upper_bound:
            more.append(fn)
        elif dur < lower_bound:
            less.append(fn)
    return CheckResult(less=less,more=more)

In [71]:
futures = []
with ProcessPoolExecutor(num_worker) as executor:
    for i in range(num_worker):
        futures.append(executor.submit(checking,filename=files,idc=indices[i*n_sample_per_worker:(i+1)*n_sample_per_worker]))

In [72]:
future_res_less = []
future_res_more = []
for future in tqdm(futures):
    res = future.result()
    future_res_less += res.less
    future_res_more += res.more

  0%|          | 0/30 [00:00<?, ?it/s]

In [73]:
print(f'Too short files: {len(future_res_less)}')
print(f'Too long files: {len(future_res_more)}')

Too short files: 630
Too long files: 0


In [19]:
open('need_segmenting.txt','w').write('\n'.join(future_res_more))

268066

In [20]:
open('too_short.txt','w').write('\n'.join(future_res_less))

19154

#### Pass 2: Audio Segmentation on duration

In [45]:
test = AudioSegment.from_wav('data/segmented/ent_3174_8.wav')

In [46]:
len(test)

27833

In [48]:
for f in tqdm(future_res_more):
    fname = os.path.basename(f)
    au = AudioSegment.from_wav(f)

    dv = 2
    while len(au) // dv > max_len * 1000:
        dv += 1
    
    chunk_size = len(au) // dv        
    for i in range(dv):
        au[i*chunk_size:(i+1)*chunk_size].export(os.path.join(path,fname[:-4] + f'_{i}.wav'),format='wav')
    #os.remove(f)

  0%|          | 0/8827 [00:00<?, ?it/s]

##### Double checking

In [61]:
for i in glob.glob('data/segmented/ent_1492_54_*'):
    print(i)

data/segmented/ent_1492_54_0.wav
data/segmented/ent_1492_54_1.wav


In [62]:
AudioSegment.from_wav('data/segmented/ent_1492_54.wav')

In [65]:
AudioSegment.from_wav('data/segmented/ent_1492_54_0.wav')

#### Sampling rate checking

In [None]:
# check sample rate of all files
for file in tqdm(glob.glob('data/part_4/*.wav')):
    info = mediainfo(file)['sample_rate']
    if info != '16000':
        print(f'{os.path.basename(file)} has sample rate {info}')