In [None]:
%cd /data/codes/apa/train/

import pandas as pd
import os
from glob import glob
import json
from pandarallel import pandarallel
import librosa
import soundfile as sf
from tqdm import tqdm
import torchaudio
import random
import torch
import re

# Register tqdm's progress-aware pandas methods (progress_apply is used by
# the VAD worker function later in this notebook).
tqdm.pandas()

# Start pandarallel with 8 workers and progress bars enabled; this provides
# the parallel_apply calls used in the filtering cells below.
pandarallel.initialize(nb_workers=8, progress_bar=True)

In [None]:
# Every question type uses the same directory layout; only the type number
# embedded in the paths differs.
_DATE_RANGE = "01082022_18092023"


def _paths_for_question_type(question_type):
    """Return the input/output path mapping for one question type."""
    info_name = f"info_question_type-{question_type}_{_DATE_RANGE}.csv"
    return {
        "json_dir": f"/data/metadata/apa-en/marking-data/{question_type}",
        "audio_dir": f"/data/audio/prep-submission-audio/apa-type-{question_type}",
        "metadata_path": f"/data/metadata/apa-en/merged-info/{info_name}",
        "out_metadata_path": f"/data/metadata/stt-en/raw/vad-filtered-{info_name}",
    }


type2path = {
    question_type: _paths_for_question_type(question_type)
    for question_type in (10, 12)
}

In [None]:
# Choose the question type to process and look up its configured paths.
_type_ = 12
path_dict = type2path[_type_]

# Locations taken directly from the configuration.
in_audio_dir = path_dict["audio_dir"]
out_metadata_path = path_dict["out_metadata_path"]

# Working directory named after the metadata file (the text before its
# first "."), placed under the training data root.
data_root_dir = "/data/codes/apa/train/data"
data_name = os.path.basename(path_dict["metadata_path"]).partition(".")[0]
data_dir = os.path.join(data_root_dir, data_name)

# Derived output locations inside the working directory.
out_raw_json_path = f'{data_dir}/metadata-raw.jsonl'
out_audio_dir = f'{data_dir}/wav'

In [None]:
# Input locations that the helper functions below read through `hparams`.
hparams = {
    key: path_dict[key]
    for key in ("json_dir", "audio_dir", "metadata_path")
}

# Load the metadata table and keep only rows whose word_count equals 1.
metadata = pd.read_csv(hparams["metadata_path"])
metadata = metadata.loc[metadata["word_count"] == 1.0]
metadata.head(2)

In [None]:
def is_valid_audio(audio_id):
    """Check that the audio file for ``audio_id`` exists, loads, and is 16 kHz.

    Args:
        audio_id: Identifier used to build ``<audio_dir>/<audio_id>.wav``,
            where ``audio_dir`` is read from the module-level ``hparams``.

    Returns:
        bool: True only when the file exists, ``torchaudio.load`` succeeds
        and the reported sample rate is 16000; False otherwise.
    """
    abs_path = os.path.join(hparams["audio_dir"], f'{audio_id}.wav')
    if not os.path.exists(abs_path):
        return False

    try:
        # Only the sample rate is inspected; loading also proves the file
        # can be decoded.
        _, sr = torchaudio.load(abs_path)
    except Exception:
        # Any loading error marks the file as invalid. Unlike a bare
        # ``except``, KeyboardInterrupt/SystemExit still propagate so a long
        # parallel run can be stopped.
        return False

    return sr == 16000

# Validate every referenced audio file in parallel, then keep only the rows
# whose file passed; the two prints show the table shape before and after.
is_exist = metadata["id"].parallel_apply(is_valid_audio)
print(metadata.shape)
metadata = metadata.loc[is_exist]
print(metadata.shape)

In [None]:
def parse_json_file(id):
    """Collect the word time spans and the audio duration for one sample.

    The marking JSON is read from ``<json_dir>/<id>.json`` and the audio from
    ``<audio_dir>/<id>.wav`` (both directories come from the module-level
    ``hparams``).

    Args:
        id: Sample identifier shared by the JSON and WAV file names.

    Returns:
        dict | None: ``{"segments": [[start_time, end_time], ...],
        "duration": <number of samples / sample rate>}`` with one segment per
        word across all utterances, in file order. None when the audio or
        JSON cannot be read or lacks the expected keys.
    """
    json_path = os.path.join(hparams["json_dir"], f'{id}.json')
    audio_path = os.path.join(hparams["audio_dir"], f'{id}.wav')

    try:
        waveform, sr = librosa.load(audio_path, sr=16000)
        duration = waveform.shape[0] / sr

        with open(json_path, "r") as f:
            content = json.load(f)

        # One [start, end] pair per word, flattened across utterances.
        segments = [
            [word["start_time"], word["end_time"]]
            for raw_utterance in content["utterance"]
            for word in raw_utterance["words"]
        ]
    except Exception:
        # Unreadable or malformed samples are reported as None so the caller
        # can drop them. KeyboardInterrupt/SystemExit are not swallowed.
        return None

    return {
        "segments": segments,
        "duration": duration
    }

# Parse every sample in parallel; a None result marks a sample that could
# not be parsed, and those rows are removed from the metadata table.
parsed_segments = metadata["id"].parallel_apply(parse_json_file)
metadata = metadata[parsed_segments.notna()]
parsed_segments.head()

In [None]:
# Discard the entries for samples that failed to parse.
parsed_segments = parsed_segments.dropna()

In [None]:
def get_silence_segment(segments, duration):
    """Return the gaps between word segments as ``[start, end]`` pairs.

    Args:
        segments: Sequence of ``[start, end]`` word spans, in order.
        duration: End of the audio; used as the end of the trailing gap.

    Returns:
        list: The gap before the first word (starting at 0), the gaps between
        consecutive words, and finally ``[end of last word, duration]``.
        Gaps with no positive length (touching or overlapping words) are
        skipped. The trailing gap is always appended, so the result is never
        empty; for an empty ``segments`` it is ``[[0, duration]]``.
    """
    silence_segments = []

    gap_start = 0  # the first gap starts at the beginning of the audio
    for word in segments:
        gap_end = word[0]
        if gap_end > gap_start:
            silence_segments.append([gap_start, gap_end])
        # The next gap starts where this word ends.
        gap_start = word[1]

    # With no words at all, the whole audio counts as one gap.
    tail_start = segments[-1][-1] if len(segments) > 0 else 0
    silence_segments.append([tail_start, duration])
    return silence_segments

# Derive the gaps between words for every parsed sample, in parallel.
silence_segments = parsed_segments.parallel_apply(
    lambda parsed: get_silence_segment(parsed["segments"], parsed["duration"])
)

In [None]:
# Pair every sample id with its list of gap segments; this table is what the
# VAD workers iterate over.
silence_columns = dict(id=metadata["id"], segments=silence_segments)
df = pd.DataFrame(silence_columns)

df.head()

In [None]:
from torch import hub
import torchaudio
import librosa

class Voice_Activity_Detection():
    """Wrapper around the Silero VAD model loaded through ``torch.hub``."""

    def __init__(self, sample_rate=16000, device="cuda"):
        """Load the ``silero_vad`` model and move it to ``device``.

        Args:
            sample_rate: Sampling rate forwarded to the VAD helper and used
                to convert detected speech length into seconds.
            device: Device the model and the input audio are moved to.
        """
        self.device = device
        self.sample_rate = sample_rate

        self.model, self.utils = hub.load(
            repo_or_dir="snakers4/silero-vad",
            model="silero_vad",
            force_reload=False,
            onnx=False
        )
        self.model.to(device)

        # The hub entry point returns its helper callables as one tuple.
        self.fn_get_speech_timestamps, self.fn_save_audio, \
            self.fn_read_audio, self.VADIterator, self.fn_collect_chunks = self.utils

    @torch.no_grad()
    def get_speech_timestamps(self, segments, threshold=0.7):
        """Run the VAD helper on ``segments`` and return its timestamps.

        Args:
            segments: Audio tensor; it is moved to ``self.device`` first.
            threshold: Threshold forwarded to the VAD helper.

        Returns:
            Whatever the Silero helper returns; ``is_valid_segment`` treats it
            as a sequence of dicts with ``"start"`` and ``"end"`` entries.
        """
        timestamps = self.fn_get_speech_timestamps(
            segments.to(self.device), self.model, threshold=threshold, sampling_rate=self.sample_rate)

        return timestamps

    @torch.no_grad()
    def is_valid_segment(self, segment, threshold=0.7, min_duration=0.2):
        """Tell whether ``segment`` contains less than ``min_duration`` of speech.

        Args:
            segment: Audio tensor expected to be (mostly) silent.
            threshold: Threshold forwarded to the VAD helper.
            min_duration: Upper bound for tolerated speech, compared against
                the summed ``end - start`` of all detections divided by
                ``self.sample_rate``.

        Returns:
            bool: True when nothing is detected or the detected total stays
            below ``min_duration``; False otherwise.
        """
        # Reuse the shared wrapper instead of repeating the helper call.
        timestamps = self.get_speech_timestamps(segment, threshold=threshold)

        if len(timestamps) == 0:
            return True

        # Total detected speech; presumably counted in samples, since it is
        # divided by the sample rate below.
        speech_length = sum(
            detection["end"] - detection["start"] for detection in timestamps)

        return (speech_length / self.sample_rate) < min_duration
    
# Build one VAD model on the CPU in the notebook process.
vad_model = Voice_Activity_Detection(device="cpu")

In [None]:
def is_valid_sample(id, segments, vad_model, threshold=0.7, min_duration=0.2):
    """Check that the given gap segments of a sample contain almost no speech.

    The audio ``<audio_dir>/<id>.wav`` (``audio_dir`` from the module-level
    ``hparams``) is loaded at 16 kHz, the listed spans are cut out and
    concatenated, and the result is passed to ``vad_model.is_valid_segment``.

    Args:
        id: Sample identifier used to build the audio file name.
        segments: Iterable of ``(start, end)`` pairs; they are multiplied by
            the sample rate to obtain slice bounds.
        vad_model: Object exposing ``is_valid_segment(tensor, threshold=...,
            min_duration=...)``.
        threshold: Threshold forwarded to the VAD model (default 0.7).
        min_duration: Tolerated amount of speech forwarded to the VAD model
            (default 0.2).

    Returns:
        The result of ``vad_model.is_valid_segment``; True when ``segments``
        is empty, because there is nothing to check.
    """
    audio_path = os.path.join(hparams["audio_dir"], f'{id}.wav')
    waveform, sr = librosa.load(audio_path, sr=16000)

    silences = [
        torch.from_numpy(waveform[int(start*sr): int(end*sr)])
        for start, end in segments
    ]

    # torch.concat rejects an empty list; no gaps means nothing can be
    # flagged as speech.
    if not silences:
        return True

    silences = torch.concat(silences)
    return vad_model.is_valid_segment(
        silences, threshold=threshold, min_duration=min_duration)


In [None]:
import torch
import multiprocessing

from concurrent.futures import (
    ProcessPoolExecutor, 
    as_completed
)

# Sample rate constant for the audio handled in this notebook.
sr = 16000
# Limit torch to a single thread; the VAD work is spread over worker
# processes instead.
torch.set_num_threads(1)

In [None]:
# Registry of VAD models, keyed by the id of the process that created them.
vad_models = {}


def init_model():
    """Create a CPU VAD model and register it under the current process id."""
    current_pid = multiprocessing.current_process().pid
    vad_models[current_pid] = Voice_Activity_Detection(device="cpu")

def vad_process(df):
    """Run the VAD check on every row of ``df`` with this process's model.

    Args:
        df: Frame with ``"id"`` and ``"segments"`` columns.

    Returns:
        The row-wise ``progress_apply`` result: one ``is_valid_sample``
        outcome per row, indexed like ``df``.
    """
    worker_pid = multiprocessing.current_process().pid
    worker_model = vad_models[worker_pid]

    def check_row(row):
        return is_valid_sample(
            id=row["id"], segments=row["segments"], vad_model=worker_model)

    with torch.no_grad():
        return df.progress_apply(check_row, axis=1)


In [None]:
# Split `df` into at most `num_process` contiguous chunks, one per submitted
# task. Only non-empty chunks are produced, so no worker is handed an empty
# frame when `df` has few rows.
num_process = 16
# Ceiling division, clamped to 1 so the range step below is never zero.
num_sample_per_process = max(1, -(-df.shape[0] // num_process))

params = [
    df.iloc[offset: offset + num_sample_per_process]
    for offset in range(0, df.shape[0], num_sample_per_process)
]

In [None]:
# Submit one task per chunk; every worker process builds its own VAD model
# once through the pool initializer.
with ProcessPoolExecutor(max_workers=num_process, initializer=init_model) as executor:
    futures = [executor.submit(vad_process, param) for param in params]

# Collect the per-chunk results in completion order, then sort by index so
# the mask lines up with the metadata rows again.
results = [finished.result() for finished in as_completed(futures)]
is_valid = pd.concat(results).sort_index()

In [None]:
# Show the shape of the rows that pass the VAD check, then of those that fail.
for row_mask in (is_valid, ~is_valid):
    print(metadata[row_mask].shape)

In [None]:
# Keep only the samples that passed the VAD check.
metadata = metadata.loc[is_valid]

In [None]:
# Create the destination directory first so the final write cannot fail on a
# missing folder after the long VAD run.
os.makedirs(os.path.dirname(out_metadata_path) or ".", exist_ok=True)
# Write the filtered metadata without the index column.
metadata.to_csv(out_metadata_path, index=False)