# Working code

In [None]:
from pathlib import Path
import numpy as np
import scipy.io.wavfile
from typing import List, Tuple, Union
import soundfile as sf  # Add this import for better audio file handling

class AudioMixer:
    """Add a click-marker channel to audio files.

    Each processed file is written as a two-channel 16-bit PCM file: the first
    channel holds the first channel of the input audio, the second channel
    holds a copy of the click sound at the start and at the end of the audio.
    """

    def __init__(self, click_file: Union[Path, str]):
        """Remember which click sound to use.

        Args:
            click_file: Location of the audio file containing the click.

        Raises:
            FileNotFoundError: If ``click_file`` does not exist.
        """
        # Coerce so that plain strings are accepted as well as Path objects.
        self.click_file = Path(click_file)
        if not self.click_file.exists():
            raise FileNotFoundError(f"Click file not found: {click_file}")

    @staticmethod
    def load_sound(filename: Path) -> tuple[int, np.ndarray]:
        """Load an audio file as mono int16 samples.

        Args:
            filename: File to read with ``soundfile``.

        Returns:
            ``(sample_rate, data)``; for multi-channel files only the first
            channel is kept.

        Raises:
            ValueError: If the file cannot be read (the original error is
                chained as ``__cause__``).
        """
        try:
            data, sample_rate = sf.read(filename, dtype='int16')
        except Exception as e:
            raise ValueError(f"Could not load audio file: {e}") from e
        # Ensure mono: keep the first channel of multi-channel audio.
        if data.ndim > 1:
            data = data[:, 0]
        return sample_rate, data

    @staticmethod
    def normalize_audio(audio_data: np.ndarray) -> np.ndarray:
        """Convert audio to int16, peak-normalizing anything that is not int16.

        Args:
            audio_data: Samples of any numeric dtype.

        Returns:
            ``audio_data`` itself when it is already int16; otherwise a new
            int16 array scaled so that the largest absolute sample is 32767.
            All-zero and empty inputs are converted without scaling.
        """
        if audio_data.dtype == np.int16:
            return audio_data
        audio_float = audio_data.astype(np.float32)
        # np.max raises on an empty array, so only look for a peak when
        # there is at least one sample.
        if audio_float.size:
            max_val = np.max(np.abs(audio_float))
            if max_val > 0:
                audio_float = audio_float * (32767 / max_val)
        return audio_float.astype(np.int16)

    @staticmethod
    def write_sound(filename: Path, sample_rate: int, audio_data: np.ndarray) -> None:
        """Write audio to ``filename`` as 16-bit PCM.

        Args:
            filename: Destination file.
            sample_rate: Sample rate to store in the file.
            audio_data: Samples to write; passed through ``normalize_audio``
                first. The array is transposed before writing, so a
                ``(channels, frames)`` array such as the one built by
                ``process_file`` is handed over as ``(frames, channels)``.
        """
        normalized_data = AudioMixer.normalize_audio(audio_data)
        sf.write(filename, normalized_data.T, sample_rate, subtype='PCM_16')

    def mix_sound(self, target: np.ndarray, mix: np.ndarray,
                position: float, sample_rate: int = 22050,
                replace: bool = True) -> None:
        """Mix ``mix`` into ``target`` in place at a position given in seconds.

        The part of ``mix`` that would fall outside ``target`` (before its
        first sample or after its last one) is dropped; if nothing overlaps,
        ``target`` is left untouched.

        Args:
            target: 1-D buffer that is modified in place.
            mix: 1-D sound to insert.
            position: Start of ``mix`` inside ``target``, in seconds.
            sample_rate: Samples per second used to convert ``position``.
            replace: Overwrite the target samples when True, add to them
                when False.
        """
        start = int(sample_rate * position)
        if start < 0:
            # A negative index would be interpreted relative to the END of
            # the target; drop the samples that lie before the start instead.
            mix = mix[-start:]
            start = 0

        end = min(start + len(mix), len(target))
        if end <= start:
            # Nothing overlaps the target (start past the end, or empty mix).
            return

        segment = mix[:end - start]
        if replace:
            target[start:end] = segment
        else:
            target[start:end] += segment

    def generate_click_train(self, channel: np.ndarray,
                        positions: List[float], sample_rate: int) -> np.ndarray:
        """Build a click channel as long as ``channel``.

        Args:
            channel: Mono reference audio; only its length is used.
            positions: Click start times in seconds. The list is not
                modified. If a click at the last position would run past the
                end of ``channel``, that click is moved back so it ends with
                the audio (but never before time 0).
            sample_rate: Samples per second of ``channel``.

        Returns:
            A float64 array of zeros with the click samples copied in at the
            requested positions.

        Raises:
            ValueError: If ``channel`` is not one-dimensional, or the click
                file cannot be loaded.
        """
        if channel.ndim != 1:
            raise ValueError("Input channel must be mono")

        # np.zeros defaults to float64. Stacking this with an int16 channel
        # therefore yields float64, which write_sound then peak-normalizes.
        channel2 = np.zeros(len(channel))
        if len(positions) == 0:
            return channel2

        # NOTE: the click's own sample rate is not compared with sample_rate;
        # the click samples are copied as they are.
        _click_sr, click = self.load_sound(self.click_file)

        # Work on a copy so the caller's list is left untouched.
        positions = list(positions)
        if positions[-1] * sample_rate + len(click) > len(channel):
            positions[-1] = max(0.0, (len(channel) - len(click)) / sample_rate)

        self.multi_mix(channel2, click, positions, sample_rate=sample_rate)
        return channel2

    def _process_file(self, input_file: Path, output_file: Path) -> None:
        """Do the work of ``process_file`` and let any error propagate."""
        sample_rate, channel1 = self.load_sound(input_file)
        channel1 = self.normalize_audio(channel1)

        # Ask for clicks at the very start and at the very end of the audio;
        # generate_click_train moves the last click back so that it fits.
        duration = len(channel1) / sample_rate
        channel2 = self.generate_click_train(channel1, [0.0, duration], sample_rate)

        # Create stereo sound with audio and clicks, shape (2, frames).
        stereo_sound = np.vstack([channel1, channel2])
        self.write_sound(output_file, sample_rate, stereo_sound)

    def process_file(self, input_file: Path, output_file: Path) -> None:
        """Write ``input_file`` plus a click channel to ``output_file``.

        Clicks are placed at the start and at the end of the audio. Errors
        are reported on stdout instead of being raised.

        Args:
            input_file: Audio file to read.
            output_file: Destination of the two-channel result.
        """
        try:
            self._process_file(input_file, output_file)
        except Exception as e:
            print(f"Error processing {input_file}: {e}")

    def multi_mix(self, target: np.ndarray, mix: np.ndarray,
                 positions: List[float], sample_rate: int = 22050,
                 replace: bool = True) -> None:
        """Mix ``mix`` into ``target`` at every position in ``positions``.

        Args:
            target: 1-D buffer that is modified in place.
            mix: 1-D sound to insert.
            positions: Start times in seconds.
            sample_rate: Samples per second used to convert the positions.
            replace: Overwrite the target samples when True, add when False.
        """
        for pos in positions:
            self.mix_sound(target, mix, pos, sample_rate, replace)

    def process_directory(self, input_dir: Path, output_dir: Path) -> None:
        """Process every ``.wav`` / ``.mp3`` file found directly in ``input_dir``.

        Results are written to ``output_dir`` (created if missing) as
        ``<stem>_click.wav``. One line is printed per file: "Processed" on
        success, "Error processing" on failure; a failing file does not stop
        the run.

        Args:
            input_dir: Directory containing the input audio files.
            output_dir: Directory receiving the processed files.
        """
        input_dir = Path(input_dir)
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        for audio_file in input_dir.glob('*.*'):
            if audio_file.suffix.lower() not in ('.wav', '.mp3'):
                continue
            output_file = output_dir / f"{audio_file.stem}_click.wav"
            try:
                # Call the raising variant so that success is only reported
                # when the file was actually written.
                self._process_file(audio_file, output_file)
            except Exception as e:
                print(f"Error processing {audio_file}: {e}")
            else:
                print(f"Processed: {output_file}")



# Run the batch only when this code is executed directly (script or notebook
# cell), not when the file is imported for its AudioMixer class.
if __name__ == "__main__":
    # Setup paths
    input_dir = Path('./../stim/audio')
    output_dir = Path('./../stim/audio_with_clics')
    click_file = Path('./../stim/test_clic/click.wav')

    # Create and run mixer: every .wav/.mp3 in input_dir gets a click channel
    # and is written to output_dir.
    mixer = AudioMixer(click_file)
    mixer.process_directory(input_dir, output_dir)

# Testing
