# Notebook #2 for Trimming of the generated audio with VAD

This notebook trims each generated audio recording down to the segment that contains only the spoken word, using a voice activity detection (VAD) model. To generate the audio recordings themselves, see notebook #1.

In [None]:
import os
import shutil
import numpy as np
from tqdm import tqdm
from pathlib import Path

import librosa
import soundfile

import torch
from sgvad import SGVAD

In [None]:
class SGVAG_for_trimming(SGVAD):
    """SGVAD subclass whose ``predict`` returns thresholded per-frame decisions.

    The returned boolean tensor is what the trimming code in this notebook
    uses to locate the spoken word inside a recording.
    """

    def predict(self, wave: "str | os.PathLike | np.ndarray | torch.Tensor"):
        """Run the VAD model on one recording and threshold the frame scores.

        Args:
            wave: Either a path to an audio file (``str`` or path-like, loaded
                through ``self.load_audio``), or an already loaded waveform as
                a ``torch.Tensor`` or ``numpy.ndarray``.

        Returns:
            Boolean tensor ``score >= self.cfg.threshold``, where ``score`` is
            the clamped model output summed over ``dim=1``.
            NOTE(review): callers in this notebook index the result as
            ``preds[0, t]`` with ``t`` a window index - confirm the model
            output layout if that changes.

        Raises:
            ValueError: If ``wave`` has an unsupported type (kept as
                ``ValueError`` for backward compatibility).
        """
        if isinstance(wave, (str, os.PathLike)):
            wave = self.load_audio(wave)
            # load_audio may already hand back a tensor; only convert if needed.
            if not isinstance(wave, torch.Tensor):
                wave = torch.tensor(wave)
        elif isinstance(wave, torch.Tensor):
            # Same copy semantics as torch.tensor(wave), without the
            # "copy construct from a tensor" UserWarning.
            wave = wave.detach().clone()
        elif isinstance(wave, np.ndarray):
            wave = torch.tensor(wave)
        else:
            raise ValueError(
                'unsupported type for "wave": expected a path, a numpy.ndarray '
                f"or a torch.Tensor, got {type(wave).__name__}"
            )

        # The preprocessor is fed a batch of one signal plus its length.
        wave = wave.reshape(1, -1)
        wave_len = torch.tensor([wave.size(-1)]).reshape(1)
        processed_signal, processed_signal_len = self.preprocessor(
            input_signal=wave, length=wave_len)

        with torch.no_grad():
            mu, _ = self.model(audio_signal=processed_signal,
                               length=processed_signal_len)
            # Shift and clamp the model output into [0, 1] gates, then
            # accumulate the gates along dim=1 to get one score per frame.
            binary_gates = torch.clamp(mu + 0.5, 0.0, 1.0)
            score = binary_gates.sum(dim=1)

        return score >= self.cfg.threshold

## 1. Configs

In [None]:
# Input: raw generated recordings, read as <base_dir>/<subset>/<word>/<file>.
base_dir = Path("data/Synt-RuSC/raw/")
# Output: root directory for the trimmed recordings.
output_data_path = Path("data/Synt-RuSC/interim/")

# Fail early if the input directory is missing. An explicit raise is used
# instead of `assert`, because asserts are stripped when Python runs with -O.
if not base_dir.is_dir():
    raise NotADirectoryError(
        f'path does not exist, check "base_dir": {base_dir}'
    )

# Create the output directory (and any missing parents) if needed.
output_data_path.mkdir(parents=True, exist_ok=True)


# Only word folders listed here are processed; everything else is skipped.
target_words = [
    "yes", "no", "up", "down", "left",
    "right", "on", "off", "stop", "go",
    "zero", "one", "two", "three", "four",
    "five", "six", "seven", "eight", "nine",
    "backward", "forward", "follow", "learn", "visual",
    "grief", "dies", "over", "motto", "newer",
    "rustier", "knock", "exclude", "nails", "discord",
    "harm", "blow_off", "create", "cry",
]

## 2. Trimming of the generated audio

In [None]:
# Load the VAD model from its checkpoint and set the decision threshold
# that `predict` compares the per-frame scores against.
sgvad_model = SGVAG_for_trimming.init_from_ckpt()
sgvad_model.cfg.threshold = 1.1

# Multiplier used later to turn a VAD window index into an audio sample
# offset: the preprocessor's window stride times the sample rate.
window_in_samples = (
    sgvad_model.cfg["preprocessor"]["window_stride"]
    * sgvad_model.cfg["sample_rate"]
)

In [None]:
# Trim every recording of every target word down to the spoken segment.
#
# For each file:
#   1. run the VAD model to get one boolean decision per window;
#   2. fill single-window gaps between two speech windows;
#   3. split the windows into chunks and pick the longest one;
#   4. grow the selection with neighbouring chunks (left, then right,
#      alternating) until it reaches a minimum duration;
#   5. save the clip only if its duration is between 400 ms and 1000 ms.
# Files with no detected speech at all are copied to "<subset>_bad".

sample_rate = sgvad_model.cfg.sample_rate

# Words that get a larger minimum clip duration (850 ms instead of 700 ms).
# NOTE(review): "above" is not in `target_words`, so it never matches here.
long_words = ("four", "rustier", "forward", "exclude", "above")

for subset in ["train", "test", "validation"]:
    good_dir = output_data_path / f"{subset}_good_max_spread"
    bad_dir = output_data_path / f"{subset}_bad"
    good_dir.mkdir(parents=True, exist_ok=True)

    for word in os.listdir(base_dir / subset):
        if word not in target_words:
            continue
        print(subset, word)

        (good_dir / word).mkdir(parents=True, exist_ok=True)

        for fName in tqdm(os.listdir(base_dir / subset / word)):
            src_path = base_dir / subset / word / fName
            dst_path = good_dir / word / fName

            # Already trimmed on a previous run.
            if dst_path.exists():
                continue

            # load audio
            audio = sgvad_model.load_audio(src_path)

            # too long audio filter (more than 5 seconds)
            if len(audio) / sample_rate > 5:
                continue

            # we got a vector preds of dimension T
            # here T is the number of windows that the audio was split into
            preds = sgvad_model.predict(audio)

            # Fill isolated one-window gaps: speech, non-speech, speech.
            for i in range(preds.shape[1] - 2):
                if preds[0, i] and preds[0, i + 2] and not preds[0, i + 1]:
                    preds[0, i + 1] = True

            # No speech detected at all: keep a copy for manual inspection.
            if not preds.any():
                # The original never created this directory, so the copy
                # failed with FileNotFoundError.
                (bad_dir / word).mkdir(parents=True, exist_ok=True)
                shutil.copy(src_path, bad_dir / word / fName)
                continue

            # Split the window indices right after every non-speech window,
            # so each chunk is a run of speech windows closed by at most one
            # non-speech window.
            speech_windows = np.array_split(
                np.arange(len(preds[0])),
                np.where(~preds[0])[0] + 1,
            )
            # When the last window is non-speech, array_split yields a
            # trailing empty chunk; drop it so chunk[0] / chunk[-1] are
            # always valid. Only the last chunk can be empty, so the
            # indices of the remaining chunks are unchanged.
            speech_windows = [chunk for chunk in speech_windows if len(chunk)]

            # selected the longest continuous sequence of windows
            # for which the VAD model's score exceeded the threshold
            speech_windows_lens = [len(chunk) for chunk in speech_windows]
            max_window_idx = int(np.argmax(speech_windows_lens))
            max_window = speech_windows[max_window_idx]

            # minimum clip length (in samples) to extend the selection to
            if word in long_words:
                musthave_samples = 850 * sample_rate / 1000
            else:
                musthave_samples = 700 * sample_rate / 1000

            start_sample = int(max_window[0] * window_in_samples)
            end_sample = int(max_window[-1] * window_in_samples)

            # Extend with neighbouring chunks, alternating left and right,
            # moving one chunk further out after each left/right pair.
            last_chunk_idx = len(speech_windows) - 1
            shift = 1
            added_on_left = False
            while end_sample - start_sample < musthave_samples:
                if start_sample == 0 and end_sample >= len(audio) - 1:
                    break

                left_idx = max_window_idx - shift
                right_idx = max_window_idx + shift

                # Bounds are checked against the number of chunks. The
                # original compared against len(audio), which let right_idx
                # run past the chunk list and raise IndexError.
                if left_idx < 0 and right_idx > last_chunk_idx:
                    break

                if not added_on_left:
                    if left_idx >= 0:
                        start_sample = int(
                            speech_windows[left_idx][0] * window_in_samples
                        )
                    added_on_left = True
                else:
                    if right_idx <= last_chunk_idx:
                        end_sample = int(
                            speech_windows[right_idx][-1] * window_in_samples
                        )
                    added_on_left = False
                    shift += 1

            start_ms = start_sample / sample_rate * 1000
            end_ms = end_sample / sample_rate * 1000

            # if the resulting sequence was shorter than 400 ms or
            # longer than 1000 ms, the audio was discarded
            if 1000 > (end_ms - start_ms) > 400:
                soundfile.write(
                    dst_path,
                    audio[start_sample:end_sample],
                    samplerate=sample_rate,
                )