# In [None]:
import numpy as np
from pydub import AudioSegment
from pydub.silence import split_on_silence, detect_nonsilent
import librosa
import soundfile as sf
import webrtcvad
import io
import warnings
# Silences every warning category for the whole process, not just this module.
# NOTE(review): this also hides warnings from unrelated libraries; consider
# narrowing it to the specific category/module that was noisy.
warnings.filterwarnings('ignore')

class AudioChunker:
    """Split long recordings into chunks bounded by a minimum and maximum duration.

    The pipeline combines pydub silence splitting, WebRTC voice-activity
    detection for pieces that are still too long, and a final pass that merges
    pieces that are too short (without ever exceeding the maximum).
    """

    # Sample rates (Hz) and frame durations (ms) accepted by WebRTC VAD.
    _VAD_SAMPLE_RATES = (8000, 16000, 32000, 48000)
    _VAD_FRAME_DURATIONS_MS = (10, 20, 30)

    def __init__(self, sample_rate=16000, min_chunk_duration=10, max_chunk_duration=30):
        """
        Initialize the audio chunker with specified sample rate and chunk duration bounds.
        Args:
            sample_rate (int): Sample rate of the audio (default: 16000 Hz)
            min_chunk_duration (int): Minimum duration of any chunk in seconds (default: 10)
            max_chunk_duration (int): Maximum duration of any chunk in seconds (default: 30)
        """
        self.sample_rate = sample_rate
        self.min_chunk_duration = min_chunk_duration
        self.max_chunk_duration = max_chunk_duration
        self.vad = webrtcvad.Vad(3)  # Aggressiveness mode 3 (most aggressive)

    @staticmethod
    def _float_to_int16(audio_data):
        """
        Convert float samples in [-1.0, 1.0] to 16-bit PCM.

        Samples are clipped to the int16 range before the cast. Otherwise a
        full-scale +1.0 (32768), or a slight overshoot introduced by
        resampling, overflows int16 and wraps to a large negative value.
        Args:
            audio_data (numpy.ndarray): Float audio signal
        Returns:
            numpy.ndarray: int16 samples
        """
        scaled = np.asarray(audio_data) * 32768.0
        return np.clip(scaled, -32768.0, 32767.0).astype(np.int16)

    def _duration_seconds(self, chunk):
        """
        Return the duration of a chunk in seconds.
        Args:
            chunk: Either an AudioSegment (anything exposing ``frame_count()``
                and ``frame_rate``) or a sample array at ``self.sample_rate``.
        Returns:
            float: Duration in seconds
        """
        frame_count = getattr(chunk, "frame_count", None)
        if frame_count is None:
            # Plain sample array (e.g. numpy) recorded at the chunker's rate.
            return len(chunk) / self.sample_rate
        # AudioSegment: use its own frame count and rate instead of
        # materialising the whole sample array just to measure its length.
        return frame_count() / chunk.frame_rate

    def numpy_to_audiosegment(self, audio_data):
        """
        Convert numpy array to AudioSegment.
        Args:
            audio_data (numpy.ndarray): Mono audio signal (-1.0 to 1.0 float32)
        Returns:
            AudioSegment: 16-bit mono AudioSegment at self.sample_rate
        """
        return AudioSegment(
            self._float_to_int16(audio_data).tobytes(),
            frame_rate=self.sample_rate,
            sample_width=2,
            channels=1,
        )

    def audiosegment_to_numpy(self, audio_segment):
        """
        Convert AudioSegment to numpy array.
        Args:
            audio_segment (AudioSegment): Audio data as AudioSegment
        Returns:
            numpy.ndarray: Audio signal as float32 numpy array scaled to [-1.0, 1.0)
        """
        # NOTE: assumes a mono segment; for multi-channel audio the samples
        # returned by get_array_of_samples() are left interleaved.
        samples = np.array(audio_segment.get_array_of_samples())
        # Full-scale magnitude for the segment's integer width (32768 for 16-bit).
        full_scale = float(1 << (8 * audio_segment.sample_width - 1))
        return samples.astype(np.float32) / full_scale

    def load_audio(self, file_path):
        """
        Load an audio file with librosa, resampled to self.sample_rate.

        No amplitude normalisation is applied beyond librosa's float decoding.
        Args:
            file_path (str): Path to the audio file
        Returns:
            tuple: (audio_data, sample_rate)
        """
        audio_data, sr = librosa.load(file_path, sr=self.sample_rate)
        return audio_data, sr

    def merge_small_chunks(self, chunks, min_duration):
        """
        Merge chunks shorter than min_duration with their neighbours.

        A short chunk is merged forward into the next chunk when the combined
        length fits within max_chunk_duration. Otherwise it is appended to the
        previously emitted chunk if that fits. If neither fits it is kept as-is,
        because max_chunk_duration is treated as a hard limit.
        Args:
            chunks (list): List of AudioSegment chunks
            min_duration (float): Minimum duration in seconds
        Returns:
            list: Merged chunks as AudioSegment objects
        """
        if not chunks:
            return list(chunks)

        max_duration = self.max_chunk_duration
        merged = []
        current_chunk = chunks[0]

        for next_chunk in chunks[1:]:
            current_duration = self._duration_seconds(current_chunk)
            if current_duration < min_duration:
                # Prefer growing forward so speech stays in temporal order.
                if current_duration + self._duration_seconds(next_chunk) <= max_duration:
                    current_chunk = current_chunk + next_chunk
                    continue
                # Forward merge would overflow; try attaching to the previous chunk.
                if merged and self._duration_seconds(merged[-1]) + current_duration <= max_duration:
                    merged[-1] = merged[-1] + current_chunk
                    current_chunk = next_chunk
                    continue
            merged.append(current_chunk)
            current_chunk = next_chunk

        # Last chunk: fold it into the previous one if it is short and that fits.
        last_duration = self._duration_seconds(current_chunk)
        if (
            merged
            and last_duration < min_duration
            and self._duration_seconds(merged[-1]) + last_duration <= max_duration
        ):
            merged[-1] = merged[-1] + current_chunk
        else:
            merged.append(current_chunk)

        return merged

    def chunk_by_silence(self, audio_data, min_silence_len=500, silence_thresh=-40, keep_silence=300):
        """
        Split audio into chunks based on silence detection.
        Args:
            audio_data (numpy.ndarray): Audio signal
            min_silence_len (int): Minimum length of silence (in ms)
            silence_thresh (int): Silence threshold in dB
            keep_silence (int): Silence (in ms) kept at the start and end of each chunk
        Returns:
            list: List of audio chunks as AudioSegment objects
        """
        audio_segment = self.numpy_to_audiosegment(audio_data)
        return split_on_silence(
            audio_segment,
            min_silence_len=min_silence_len,
            silence_thresh=silence_thresh,
            keep_silence=keep_silence,
        )

    def chunk_by_vad(self, audio_data, frame_duration=30):
        """
        Split audio into chunks using WebRTC Voice Activity Detection.

        Only runs of frames classified as speech are returned; non-speech
        frames are dropped.
        Args:
            audio_data (numpy.ndarray): Audio signal
            frame_duration (int): Duration of each frame in milliseconds (10, 20, or 30)
        Returns:
            list: List of audio chunks as float32 numpy arrays
        Raises:
            ValueError: If frame_duration or self.sample_rate is not supported by WebRTC VAD
        """
        # Fail loudly: with an unsupported configuration every frame would be
        # rejected, and the audio would silently vanish from the output.
        if frame_duration not in self._VAD_FRAME_DURATIONS_MS:
            raise ValueError(
                f"frame_duration must be one of {self._VAD_FRAME_DURATIONS_MS} ms, got {frame_duration}"
            )
        if self.sample_rate not in self._VAD_SAMPLE_RATES:
            raise ValueError(
                f"WebRTC VAD supports sample rates {self._VAD_SAMPLE_RATES}, got {self.sample_rate}"
            )

        audio_int16 = self._float_to_int16(audio_data)
        num_samples = len(audio_int16)
        if num_samples == 0:
            return []

        frame_size = int(self.sample_rate * frame_duration / 1000)

        # Zero-pad so the signal splits into whole frames (the VAD needs exact sizes).
        padding = (-num_samples) % frame_size
        padded = np.pad(audio_int16, (0, padding)) if padding else audio_int16
        frames = padded.reshape(-1, frame_size)

        is_speech = []
        for frame in frames:
            try:
                is_speech.append(self.vad.is_speech(frame.tobytes(), self.sample_rate))
            except Exception:
                # Best effort: a frame the VAD rejects is treated as non-speech.
                is_speech.append(False)

        # Collect runs of consecutive speech frames as slices of the unpadded
        # signal, so the padding never leaks into a returned chunk.
        chunks = []
        run_start = None
        for i, speech in enumerate(is_speech + [False]):  # sentinel closes a trailing run
            if speech and run_start is None:
                run_start = i
            elif not speech and run_start is not None:
                chunks.append(audio_int16[run_start * frame_size:min(i * frame_size, num_samples)])
                run_start = None

        # Convert int16 chunks back to float32
        return [chunk.astype(np.float32) / 32768.0 for chunk in chunks]

    def adaptive_chunking(self, audio_data, min_silence_len=500, silence_thresh=-40, vad_frame_duration=30):
        """
        Multi-stage chunking process:
        1. First splits audio based on silence
        2. If any chunk exceeds max_duration, splits it further using VAD
           (falling back to fixed-size splits if VAD finds no speech at all)
        3. Merges chunks smaller than min_duration

        Args:
            audio_data (numpy.ndarray): Audio signal
            min_silence_len (int): Minimum length of silence (in ms)
            silence_thresh (int): Silence threshold in dB
            vad_frame_duration (int): Frame duration for VAD in ms
        Returns:
            list: List of audio chunks as AudioSegment objects
        """
        silence_chunks = self.chunk_by_silence(audio_data, min_silence_len, silence_thresh)

        max_samples = int(self.max_chunk_duration * self.sample_rate)
        intermediate_chunks = []

        for chunk in silence_chunks:
            if self._duration_seconds(chunk) <= self.max_chunk_duration:
                intermediate_chunks.append(chunk)
                continue

            chunk_np = self.audiosegment_to_numpy(chunk)
            vad_chunks = self.chunk_by_vad(chunk_np, vad_frame_duration)
            if not vad_chunks:
                # The silence splitter judged this audio non-silent, but VAD found
                # no speech. Keep it rather than silently discarding the whole span.
                vad_chunks = [chunk_np]

            # Anything still over the limit is cut into equal fixed-size pieces.
            for vad_chunk in vad_chunks:
                num_subchunks = max(1, int(np.ceil(len(vad_chunk) / max_samples)))
                for subchunk in np.array_split(vad_chunk, num_subchunks):
                    intermediate_chunks.append(self.numpy_to_audiosegment(subchunk))

        return self.merge_small_chunks(intermediate_chunks, self.min_chunk_duration)

    def save_chunks(self, chunks, output_dir, prefix="chunk"):
        """
        Save audio chunks to WAV files named ``{prefix}_{i}.wav``.
        Args:
            chunks (list): List of AudioSegment chunks or numpy sample arrays
            output_dir (str): Directory to save chunks (created if missing)
            prefix (str): Prefix for chunk filenames
        """
        import os
        os.makedirs(output_dir, exist_ok=True)

        for i, chunk in enumerate(chunks):
            output_path = os.path.join(output_dir, f"{prefix}_{i}.wav")
            # Measured per type so numpy chunks no longer crash before reaching sf.write.
            chunk_duration = self._duration_seconds(chunk)
            if isinstance(chunk, AudioSegment):
                chunk.export(output_path, format="wav")
            else:
                sf.write(output_path, chunk, self.sample_rate)
            print(f"Saved chunk {i}: {chunk_duration:.2f} seconds")

# Example usage
if __name__ == "__main__":
    audio_file = "/mnt/data/ashwin/SpeechRAG/speech_retrieval/data/raw_audio/long_news.mp3"

    # Every produced chunk should land between 10 s and 30 s.
    chunker = AudioChunker(min_chunk_duration=10, max_chunk_duration=30)

    # Decode the recording at the chunker's configured sample rate.
    audio_data, sr = chunker.load_audio(audio_file)

    # Silence split, VAD refinement of over-long pieces, then merging of short ones.
    chunks = chunker.adaptive_chunking(
        audio_data,
        min_silence_len=500,    # ms of quiet required to split
        silence_thresh=-40,     # level (dB) below which audio counts as silence
        vad_frame_duration=30,  # ms per WebRTC VAD frame
    )

    # Write each chunk out as a WAV file and report its length.
    chunker.save_chunks(chunks, "output_chunks")