In [1]:
# Run from the repository root so relative paths such as 'datasets/torchlibri'
# resolve. A bare `%cd ../../` is not idempotent (each re-run climbs two more
# directory levels), so the target is computed once and reused on later runs.
import os
from pathlib import Path

if 'PROJECT_ROOT' not in globals():
    PROJECT_ROOT = Path.cwd().parents[1]
os.chdir(PROJECT_ROOT)
Path.cwd()

/home/timur.bikbulatov/personal/aa_on_vad
/home/timur.bikbulatov/personal/aa_on_vad


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [2]:
import torchaudio

In [3]:
import soundfile as sf
from pathlib import Path
import numpy as np
import torch
import IPython
from functools import partial
# `playr(data, rate=...)` renders an inline audio player at an explicit rate;
# `play(data)` is the same player pinned to 16 kHz.
playr = IPython.display.Audio
play = partial(playr, rate=16000)

In [9]:
import plotly.graph_objects as go
import numpy as np

def plot(y: list, title: str = 'Overlaid signals', names: list = None):
    """
    Overlay several 1-D signals as line traces on one interactive plotly figure.

    Args:
        y: Sequence of 1-D array-likes (lists, numpy arrays, CPU torch tensors).
            Each element becomes one trace plotted against its sample index.
        title: Figure title. The default describes any number of traces (the
            former fixed title "Two Line Charts on One Plot" was wrong for != 2 inputs).
        names: Optional legend labels, one per element of ``y``. Defaults to
            "arg # 1", "arg # 2", ...

    Returns:
        None. The figure is displayed with ``fig.show()``.
    """
    # Named CSS colours accepted by plotly. The original list contained 'Mint',
    # which is not a CSS colour name (plotly rejects it); 'MediumAquamarine'
    # takes its slot so the order of the other colours is unchanged.
    colors = [
        'Blue',
        'Orange',
        'Green',
        'Red',
        'Purple',
        'Magenta',
        'Cyan',
        'Brown',
        'Pink',
        'Lime',
        'Yellow',
        'Teal',
        'Olive',
        'Navy',
        'Maroon',
        'Coral',
        'Gold',
        'Indigo',
        'Turquoise',
        'Lavender',
        'MediumAquamarine',
        'Silver',
    ]
    traces = []
    for ik, y_ in enumerate(y):
        values = np.asarray(y_)  # hand plotly a plain ndarray whatever the input type
        label = names[ik] if names is not None else f'arg # {ik + 1}'
        traces.append(go.Scatter(
            x=np.arange(len(values)),
            y=values,
            mode='lines',
            name=label,
            # Cycle through the palette instead of raising IndexError past the last colour.
            line=dict(color=colors[ik % len(colors)]),
        ))

    fig = go.Figure(data=traces)
    fig.update_layout(
        title=title,
        xaxis_title='Sample index',  # x is np.arange(len(values))
        yaxis_title='Y-axis',
        showlegend=True,
    )
    fig.show()

In [4]:
# LibriSpeech "test-clean" split read from datasets/torchlibri (relative to the
# working directory). download=False, so the split is expected to be present already.
ds = torchaudio.datasets.LIBRISPEECH(
    root='datasets/torchlibri',
    url='test-clean',
    download=False,
)

In [119]:
import torchaudio
from torch.utils.data import Dataset
import torchaudio.transforms as T

def remove_silence(audio, sample_rate, energy_threshold=0.02, step_duration=0.01):
    """
    Drop low-energy (silent) frames from an audio waveform.

    The waveform is cut into consecutive, non-overlapping frames of
    ``step_duration`` seconds (the last frame may be shorter). A frame is kept
    only if its RMS amplitude, computed over all channels together, is strictly
    greater than ``energy_threshold``. Kept frames are concatenated in order.

    Args:
        audio (torch.Tensor): Input waveform, shape (num_samples,) or (channels, num_samples).
        sample_rate (int): Sampling rate of ``audio`` in Hz.
        energy_threshold (float): RMS level at or below which a frame counts as silence. Default: 0.02.
        step_duration (float): Frame length in seconds. Default: 0.01.

    Returns:
        torch.Tensor: Shape (channels, kept_samples); 1-D input comes back as
        (1, kept_samples). If every frame is silent, the result is an empty
        (channels, 0) tensor with the input's dtype and device, so callers can
        rely on the same rank as the non-empty case (e.g. ``result[0]``).

    Raises:
        ValueError: If ``step_duration * sample_rate`` truncates to fewer than one sample.
    """
    if audio.dim() == 1:
        audio = audio.unsqueeze(0)  # (1, num_samples) so 1-D and 2-D inputs share one path

    step_size = int(step_duration * sample_rate)
    if step_size <= 0:
        raise ValueError("Step size must be greater than 0.")

    # torch.split yields views of step_size samples along time; the final one may be shorter.
    voiced = [
        frame
        for frame in torch.split(audio, step_size, dim=1)
        if torch.sqrt(torch.mean(frame ** 2)) > energy_threshold
    ]

    if not voiced:
        # Keep rank, dtype and device consistent with the non-empty result.
        return audio.new_zeros((audio.shape[0], 0))
    return torch.cat(voiced, dim=1)
    

def vad_forward(waveform: torch.Tensor, vad: callable, sample_rate: int = 16000, step_sec=1, idx: int = 0) -> torch.Tensor:
    """
    Apply a VAD transform chunk-by-chunk and stitch the surviving audio back together.

    The waveform is split into consecutive chunks of ``step_sec`` seconds (the last
    chunk may be shorter). ``vad`` is called on each chunk independently, and every
    non-empty result is concatenated along the time axis (dim 1).

    NOTE(review): when ``vad`` is ``torchaudio.transforms.Vad``, torchaudio documents
    that only leading silence is trimmed. Chunking therefore means each chunk loses at
    most its own leading quiet part; confirm that this is the intended effect.

    Args:
        waveform (torch.Tensor): Input audio, shape (n,) or (channels, n).
        vad (callable): Takes a (channels, chunk_len) tensor and returns a tensor
            whose dim 0 matches the input's (required for concatenation).
        sample_rate (int): Sampling rate of ``waveform`` in Hz. Default: 16000.
        step_sec (float): Chunk length in seconds. Default: 1.
        idx (int): Identifier included in failure messages only. Default: 0.

    Returns:
        torch.Tensor: Shape (channels, m). If no chunk survives, an empty
        (channels, 0) tensor with the input's dtype and device is returned.

    Raises:
        ValueError: If ``sample_rate * step_sec`` truncates to fewer than one sample.
    """
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)  # (1, num_samples)

    chunk_size = int(sample_rate * step_sec)
    if chunk_size <= 0:
        raise ValueError("Chunk size must be greater than 0.")

    kept = []
    for chunk_idx, start in enumerate(range(0, waveform.shape[-1], chunk_size)):
        chunk = waveform[:, start:start + chunk_size]
        try:
            processed = vad(chunk)
            if processed.numel() > 0:  # VAD may discard the whole chunk
                kept.append(processed)
        except Exception as e:
            # Best effort: a failing chunk is reported and skipped instead of aborting the clip.
            print(f"VAD processing failed for idx={idx}, chunk={chunk_idx}: {e}")

    if not kept:
        # Keep rank, dtype and device consistent with the non-empty result.
        return waveform.new_zeros((waveform.shape[0], 0))
    return torch.cat(kept, dim=1)

class LibriSpeechWrapper(Dataset):
    """
    Dict-returning view over a torchaudio LIBRISPEECH dataset with optional silence removal.

    Each underlying item (waveform, sample_rate, transcript, speaker_id, chapter_id,
    utterance_id) is repackaged as a dict. Optionally, the waveform is first passed
    through a chunked torchaudio VAD (``vad_forward``) and/or an energy-based
    silence gate (``remove_silence``).
    """

    def __init__(self, librispeech_dataset, sample_rate=16000, erase_silence=False, apply_vad=False,
                 silence_threshold=0.01, silence_step=0.02):
        """
        Initialize the wrapper with a torchaudio.datasets.LIBRISPEECH dataset.

        Args:
            librispeech_dataset: Instance of torchaudio.datasets.LIBRISPEECH (or any
                indexable yielding the same 6-tuples).
            sample_rate: Rate in Hz the torchaudio Vad transform is configured for.
                Default: 16000.
            erase_silence: If True, drop low-energy frames with ``remove_silence``.
            apply_vad: If True, run ``torchaudio.transforms.Vad`` chunk-wise via ``vad_forward``.
            silence_threshold: RMS threshold passed to ``remove_silence``. Default: 0.01.
            silence_step: Frame length in seconds passed to ``remove_silence``. Default: 0.02.
        """
        self.dataset = librispeech_dataset
        self.apply_vad = apply_vad
        self.erase_silence = erase_silence
        self.silence_threshold = silence_threshold
        self.silence_step = silence_step
        # Taken from the argument rather than self.dataset[0]: reading item 0 decodes a
        # whole audio file at construction and raises IndexError on an empty dataset.
        self.sample_rate = sample_rate
        if apply_vad:
            self.vad = T.Vad(sample_rate=sample_rate)  # Voice Activity Detection transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """
        Fetch one sample, optionally strip silence, and return it as a dictionary.

        Args:
            idx: Index of the sample to retrieve.

        Returns:
            A dictionary with keys:
            - 'sample': 1-D audio waveform (first channel), processed if
              ``apply_vad`` / ``erase_silence`` are set
            - 'sample_rate': Sampling rate of the audio
            - 'Transcript': Transcript of the audio
            - 'speaker_id': Speaker ID
            - 'Chapter ID': Chapter ID
            - 'Utterance ID': Utterance ID
        """
        waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id = self.dataset[idx]

        # Use the clip's own rate so durations in seconds map to the right sample counts.
        # NOTE(review): self.vad was built for self.sample_rate; clips at another rate
        # would be analysed with a mismatched VAD configuration.
        if self.apply_vad:
            waveform = vad_forward(waveform=waveform, vad=self.vad, sample_rate=sample_rate, idx=idx)

        if self.erase_silence:
            waveform = remove_silence(waveform, sample_rate,
                                      step_duration=self.silence_step,
                                      energy_threshold=self.silence_threshold)

        if waveform.ndim > 1:
            waveform = waveform[0, :]  # keep the first channel only
        return {
            'sample': waveform,
            'sample_rate': sample_rate,
            'Transcript': transcript,
            'speaker_id': speaker_id,
            'Chapter ID': chapter_id,
            'Utterance ID': utterance_id,
        }

In [120]:
# Plain wrapper: no VAD and no silence removal; the reference for the processed variants.
datasets = LibriSpeechWrapper(librispeech_dataset=ds)
datasets[0]

{'sample': tensor([0.0003, 0.0003, 0.0004,  ..., 0.0021, 0.0021, 0.0016]),
 'sample_rate': 16000,
 'Transcript': 'HE HOPED THERE WOULD BE STEW FOR DINNER TURNIPS AND CARROTS AND BRUISED POTATOES AND FAT MUTTON PIECES TO BE LADLED OUT IN THICK PEPPERED FLOUR FATTENED SAUCE',
 'speaker_id': 1089,
 'Chapter ID': 134686,
 'Utterance ID': 0}

In [125]:
# Two processed views over the same split:
#   ds1: chunked torchaudio VAD
#   ds2: energy-based silence removal
ds1 = LibriSpeechWrapper(ds, apply_vad=True)
ds2 = LibriSpeechWrapper(ds, erase_silence=True)

In [126]:
# Listen to utterance 0 without any processing (`datasets` applies neither VAD nor silence removal).
play(datasets[0]['sample'])

In [127]:
# Number of samples left in utterance 1 after chunked VAD.
# NOTE(review): this inspects index 1 while the neighbouring cells use index 0; compare
# with datasets[1]['sample'].shape to see how much audio the chunked VAD discards.
ds1[1]['sample'].shape

torch.Size([11200])

In [128]:
# Listen to utterance 0 after chunked torchaudio VAD (ds1).
play(ds1[0]['sample'])

In [129]:
# Listen to utterance 0 after energy-based silence removal (ds2).
play(ds2[0]['sample'])

In [None]:
# Overlay raw, VAD-processed and energy-gated waveforms of utterance 0.
# NOTE(review): this cell has not been executed in the saved notebook.
plot([datasets[0]['sample'], ds1[0]['sample'], ds2[0]['sample']])

In [35]:
# Apply the energy gate by hand with the same settings as `ds2`, using the clip's
# own sample rate for both gating and playback instead of a hard-coded 16000.
raw_item = datasets[0]
raw_gated = remove_silence(raw_item['sample'], raw_item['sample_rate'],
                           step_duration=0.02, energy_threshold=0.01)
playr(raw_gated[0], rate=raw_item['sample_rate'])