# Audio I/O Performance Benchmarks: audio_samples vs soundfile, scipy, and torchaudio

This notebook benchmarks the performance of the `audio_samples` Python bindings against `soundfile`, `scipy.io.wavfile`, and `torchaudio` for reading and writing WAV files.

**Test Configuration:**
- Sample Rate: 44100 Hz
- Sample Type: f32
- File Durations: 0.1s, 0.5s, 1s, 2s, 5s, 10s, 30s, 60s
- Channel Configurations: Mono and Stereo

**Metrics:**
- Read performance (time to load)
- Write performance (time to save)
- Throughput analysis
- Memory efficiency comparison

## Setup and Imports

In [None]:
import time
import gc
import os
import tempfile

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import soundfile as sf
import audio_samples as aus
from scipy.io import wavfile as scipy_wav
import torchaudio
import scipy

# Okabe-Ito colorblind-safe palette. The first four entries are the ones
# assigned, in order, to the four benchmarked libraries.
OKABE_ITO_COLORS = [
    '#E69F00',  # orange         -> audio_samples
    '#56B4E9',  # sky blue       -> soundfile
    '#009E73',  # bluish green   -> scipy
    '#F0E442',  # yellow         -> torchaudio
    '#0072B2',  # blue           (unassigned)
    '#D55E00',  # vermillion     (unassigned)
    '#CC79A7',  # reddish purple (unassigned)
    '#999999'   # gray           (unassigned)
]

# Global matplotlib styling, applied in one call: large fonts, bold titles
# and labels, thick lines and big markers.
plt.rcParams.update({
    # Figure and fonts
    'figure.figsize': [12, 9],
    'font.size': 20,
    'axes.titlesize': 24,
    'axes.labelsize': 22,
    'xtick.labelsize': 18,
    'ytick.labelsize': 18,
    'legend.fontsize': 18,
    'axes.titleweight': 'bold',
    'axes.labelweight': 'bold',
    # Tick and spine stroke widths
    'xtick.major.width': 2,
    'ytick.major.width': 2,
    'xtick.minor.width': 1,
    'ytick.minor.width': 1,
    'axes.linewidth': 2,
    # Legend box
    'legend.frameon': True,
    'legend.fancybox': True,
    'legend.shadow': True,
    # Grid
    'grid.alpha': 0.3,
    'grid.linewidth': 1,
    # Data series
    'lines.linewidth': 4,
    'lines.markersize': 12,
})

# Make seaborn cycle through the same palette
sns.set_palette(OKABE_ITO_COLORS)

# Libraries under test; this order is reused for reporting and plotting
LIBRARIES = ['audio_samples', 'soundfile', 'scipy', 'torchaudio']

# Record the versions in the notebook output (audio_samples may not expose one)
print(f"audio_samples version: {getattr(aus, '__version__', 'unknown')}")
print(f"soundfile version: {sf.__version__}")
print(f"scipy version: {scipy.__version__}")
print(f"torchaudio version: {torchaudio.__version__}")

print(f"Benchmarking libraries: {LIBRARIES}")

## Configuration and Utility Functions

In [None]:
# Test configuration shared by the fixture, read and write benchmark cells
SAMPLE_RATE = 44100  # Hz
SAMPLE_TYPE = np.float32  # NOTE(review): not referenced by any other visible cell
DURATIONS = [0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0, 60.0]  # seconds
CHANNELS = [1, 2]  # mono, stereo
BENCHMARK_ITERATIONS = 50  # default number of timed repetitions in benchmark_operation
FREQUENCY = 440.0  # Hz for sine wave test signal

# Scratch directory holding the generated WAV fixtures, the files written by
# the write benchmarks, and the saved figures
temp_dir = tempfile.mkdtemp(prefix="audio_bench_")
print(f"Test files will be stored in: {temp_dir}")

def cleanup_temp_files():
    """Delete the benchmark scratch directory and report that it was removed.

    Removal is best-effort: errors raised while deleting are ignored.
    """
    from shutil import rmtree

    rmtree(temp_dir, ignore_errors=True)
    print(f"Cleaned up temporary directory: {temp_dir}")

def get_file_size_mb(filepath: str) -> float:
    """Return the on-disk size of *filepath* in mebibytes (bytes / 1024**2)."""
    size_in_bytes = os.path.getsize(filepath)
    bytes_per_mb = 1024 * 1024
    return size_in_bytes / bytes_per_mb

def calculate_throughput(file_size_mb: float, time_seconds: float) -> float:
    """Return throughput in MB/s, or 0.0 when the elapsed time is not positive."""
    if time_seconds > 0:
        return file_size_mb / time_seconds
    # Zero (or negative) time would make the ratio undefined or meaningless.
    return 0.0

def format_time(seconds: float) -> str:
    """Render a duration with the most readable unit.

    Args:
        seconds: Duration in seconds.

    Returns:
        The value with three decimals and an ``s`` suffix when it is at least
        one second, with one decimal and an ``ms`` suffix when it is at least
        one millisecond, and otherwise in microseconds with one decimal.
    """
    if seconds >= 1.0:
        return f"{seconds:.3f}s"
    if seconds >= 0.001:
        return f"{seconds * 1000:.1f}ms"
    # "\u03bc" is GREEK SMALL LETTER MU. It is written as an escape so the
    # suffix cannot be garbled again by a wrong source-file encoding (the
    # previous literal had decayed into the two-character mojibake "Î¼").
    return f"{seconds * 1_000_000:.1f}\u03bcs"


## Test Data Generation

In [None]:
def generate_test_files():
    """Create one WAV fixture per (duration, channel count) combination.

    For every entry of DURATIONS and CHANNELS a sine wave is generated with
    audio_samples. For two channels the signal is stacked with a copy of
    itself scaled by 0.8. The result is saved under ``temp_dir``.

    Returns:
        A dict keyed by ``(duration, channels)``. Each value holds the file
        path, the in-memory audio object, the file size in MB and the number
        of samples per channel computed as ``int(duration * SAMPLE_RATE)``.
    """
    catalog = {}

    print("Generating test files...")
    for duration in DURATIONS:
        for channels in CHANNELS:
            signal = aus.generation.sine_wave(
                frequency=FREQUENCY,
                duration_secs=duration,
                sample_rate=SAMPLE_RATE,
                amplitude=0.5
            )
            print(f"Generated sine wave with dtype {signal.dtype}")

            if channels == 2:
                # Second channel is an attenuated copy so the two channels differ
                signal = aus.AudioSamples.stack([signal, signal * 0.8])

            wav_name = f"test_{duration}s_{channels}ch.wav"
            wav_path = os.path.join(temp_dir, wav_name)
            aus.io.save(wav_path, signal)

            entry = {
                'filepath': wav_path,
                'audio_samples': signal,
                'size_mb': get_file_size_mb(wav_path),
                'samples_per_channel': int(duration * SAMPLE_RATE)
            }
            catalog[(duration, channels)] = entry

            print(f"  {wav_name}: {get_file_size_mb(wav_path):.2f} MB")

    return catalog

# Build the WAV fixtures once; the read and write benchmarks look them up
# in `test_files` by (duration, channels)
test_files = generate_test_files()
print(f"\nGenerated {len(test_files)} test files")

## Benchmarking Framework

In [None]:
class BenchmarkTimer:
    """Context manager that measures the wall-clock time of its ``with`` body.

    A garbage collection is forced on entry, before the clock starts, so
    that collecting earlier garbage is not charged to the measured code.
    Only elapsed time is recorded; no memory statistics are collected.
    """

    def __init__(self):
        # time.perf_counter() readings; None until the matching event happens
        self.start_time = None
        self.end_time = None

    def __enter__(self):
        gc.collect()  # Clean up before measurement
        # Clear any reading left by a previous use of this instance, so that
        # `elapsed` can never mix an old end time with the new start time.
        self.end_time = None
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, *args):
        # Returns None, so an exception raised by the body still propagates.
        self.end_time = time.perf_counter()

    @property
    def elapsed(self):
        """Seconds spent inside the most recently completed ``with`` block.

        Raises:
            RuntimeError: If no ``with`` block has completed yet.
        """
        if self.start_time is None or self.end_time is None:
            raise RuntimeError(
                "BenchmarkTimer.elapsed read before a timed block completed"
            )
        return self.end_time - self.start_time

def benchmark_operation(operation, iterations=BENCHMARK_ITERATIONS):
    """Time repeated calls of *operation* and summarise the durations.

    Args:
        operation: Zero-argument callable to measure.
        iterations: Number of timed calls to attempt.

    Returns:
        A dict with ``mean_time``, ``std_time``, ``min_time`` and
        ``max_time`` in seconds over the successful calls. All four are 0.0
        when no call succeeded.

    A call that raises is reported on stdout and left out of the statistics.
    """
    summary_functions = {
        'mean_time': np.mean,
        'std_time': np.std,
        'min_time': np.min,
        'max_time': np.max,
    }
    samples = []

    for attempt in range(iterations):
        # Start every repetition from a freshly collected heap
        gc.collect()

        try:
            with BenchmarkTimer() as stopwatch:
                outcome = operation()

            samples.append(stopwatch.elapsed)

            # Drop the returned object now so results do not pile up in memory
            del outcome
            gc.collect()

        except Exception as exc:
            print(f"Error in benchmark iteration {attempt}: {exc}")

    if samples:
        return {key: summarise(samples) for key, summarise in summary_functions.items()}

    # Nothing succeeded: report zeros rather than statistics of an empty list
    return dict.fromkeys(summary_functions, 0.0)

## Read Performance Benchmarks

In [None]:
def _wav_readers(filepath):
    """Return ``{library name: zero-argument reader}`` for one WAV file.

    Each reader loads *filepath* with one library and returns the loaded
    audio. The dict is ordered like LIBRARIES, so iterating it runs the
    benchmarks in the same order as the report.
    """

    def read_audio_samples():
        return aus.io.read(filepath)

    def read_soundfile():
        data, _sr = sf.read(filepath, dtype=np.float32)
        return data

    def read_scipy():
        _sr, data = scipy_wav.read(filepath)
        # scipy returns the file's native dtype; bring it to float32 so the
        # result matches what the other readers are asked to produce.
        if data.dtype == np.int16:
            data = data.astype(np.float32) / 32767.0
        elif data.dtype == np.int32:
            data = data.astype(np.float32) / 2147483647.0
        elif data.dtype != np.float32:
            data = data.astype(np.float32)
        return data

    def read_torchaudio():
        waveform, _sr = torchaudio.load(filepath)
        return waveform

    return {
        'audio_samples': read_audio_samples,
        'soundfile': read_soundfile,
        'scipy': read_scipy,
        'torchaudio': read_torchaudio,
    }


def benchmark_read_performance():
    """Benchmark WAV read time for every library on every test file.

    Each library's reader is timed with ``benchmark_operation`` for every
    (duration, channels) fixture in ``test_files``.

    Returns:
        A list with one dict per (file, library) pair, in file order and
        then library order. Keys: ``operation`` (always ``'read'``),
        ``library``, ``duration``, ``channels``, ``file_size_mb``,
        ``mean_time``, ``std_time`` and ``throughput_mb_s``.
    """
    results = []

    print("Benchmarking read performance...")

    for duration in DURATIONS:
        for channels in CHANNELS:
            file_info = test_files[(duration, channels)]
            file_size_mb = file_info['size_mb']

            print(f"  Testing {duration}s, {channels}ch...")

            # Run all libraries first, then print, so the progress lines for
            # one file appear together after its measurements.
            rows = []
            for library, reader in _wav_readers(file_info['filepath']).items():
                stats = benchmark_operation(reader)
                rows.append({
                    'operation': 'read',
                    'library': library,
                    'duration': duration,
                    'channels': channels,
                    'file_size_mb': file_size_mb,
                    'mean_time': stats['mean_time'],
                    'std_time': stats['std_time'],
                    'throughput_mb_s': calculate_throughput(file_size_mb, stats['mean_time']),
                })
            results.extend(rows)

            # Show progress; labels are padded so the timings line up
            for row in rows:
                label = f"{row['library']}:"
                print(f"    {label:<15}{format_time(row['mean_time'])}, {row['throughput_mb_s']:.1f} MB/s")
            print()

    return results

# Run read benchmarks; `read_results` is later combined with the write
# results into a single DataFrame
read_results = benchmark_read_performance()

## Write Performance Benchmarks

In [None]:
def benchmark_write_performance():
    """Benchmark WAV write time for every library on every test signal.

    Every library writes the same signal as 32-bit float, so the amount of
    data written matches ``file_size_mb``, which is the size of the float
    fixture and the basis of the throughput figure. dtype casts, layout
    changes and tensor creation are done before timing starts, so the timed
    region contains only each library's write call.

    Returns:
        A list with one dict per (signal, library) pair, in signal order and
        then library order. Keys: ``operation`` (always ``'write'``),
        ``library``, ``duration``, ``channels``, ``file_size_mb``,
        ``mean_time``, ``std_time`` and ``throughput_mb_s``.
    """
    # Imported here because the notebook's import cell brings in torchaudio
    # but not torch itself.
    import torch

    results = []

    print("Benchmarking write performance...")

    for duration in DURATIONS:
        for channels in CHANNELS:
            file_info = test_files[(duration, channels)]
            audio_data = file_info['audio_samples']
            file_size_mb = file_info['size_mb']

            print(f"  Testing {duration}s, {channels}ch...")

            # Prepare each library's input up front (outside the timed code).
            numpy_data = audio_data.to_numpy().astype(np.float32, copy=False)
            if channels == 2:
                # assumes to_numpy() yields (channels, samples) for stereo,
                # as the layout handling here always has - TODO confirm
                frames_first = np.ascontiguousarray(numpy_data.T)    # soundfile / scipy
                channels_first = np.ascontiguousarray(numpy_data)    # torchaudio
            else:
                frames_first = np.ascontiguousarray(numpy_data)
                channels_first = np.ascontiguousarray(numpy_data.reshape(1, -1))
            torch_tensor = torch.from_numpy(channels_first)

            def target(tag):
                return os.path.join(temp_dir, f"write_test_{tag}_{duration}_{channels}.wav")

            aus_path = target('aus')
            sf_path = target('sf')
            scipy_path = target('scipy')
            torch_path = target('torch')

            def write_audio_samples():
                aus.io.save(aus_path, audio_data)
                return aus_path

            def write_soundfile():
                sf.write(sf_path, frames_first, SAMPLE_RATE, subtype='FLOAT')
                return sf_path

            def write_scipy():
                # scipy.io.wavfile.write stores float32 arrays as 32-bit
                # float WAV, so no int16 conversion is needed. Converting
                # would halve the bytes written and inflate the throughput.
                scipy_wav.write(scipy_path, SAMPLE_RATE, frames_first)
                return scipy_path

            def write_torchaudio():
                torchaudio.save(torch_path, torch_tensor, SAMPLE_RATE)
                return torch_path

            writers = {
                'audio_samples': write_audio_samples,
                'soundfile': write_soundfile,
                'scipy': write_scipy,
                'torchaudio': write_torchaudio,
            }

            # Run all libraries first, then print, so the progress lines for
            # one signal appear together after its measurements.
            rows = []
            for library, writer in writers.items():
                stats = benchmark_operation(writer)
                rows.append({
                    'operation': 'write',
                    'library': library,
                    'duration': duration,
                    'channels': channels,
                    'file_size_mb': file_size_mb,
                    'mean_time': stats['mean_time'],
                    'std_time': stats['std_time'],
                    'throughput_mb_s': calculate_throughput(file_size_mb, stats['mean_time']),
                })
            results.extend(rows)

            # Show progress; labels are padded so the timings line up
            for row in rows:
                label = f"{row['library']}:"
                print(f"    {label:<15}{format_time(row['mean_time'])}, {row['throughput_mb_s']:.1f} MB/s")
            print()

    return results

# Run write benchmarks; `write_results` is later combined with the read
# results into a single DataFrame
write_results = benchmark_write_performance()

## Results Analysis

In [None]:
# One row per (operation, library, duration, channels) measurement
all_results = [*read_results, *write_results]
df = pd.DataFrame(all_results)

print("Benchmark Results Summary:")
print("=" * 60)

# Per-operation averages across all durations and channel counts
for operation in ('read', 'write'):
    print(f"\n{operation.upper()} Performance:")

    for lib in LIBRARIES:
        selected = df.loc[(df['operation'] == operation) & (df['library'] == lib)]
        avg_time = selected['mean_time'].mean()
        avg_throughput = selected['throughput_mb_s'].mean()

        print(f"  {lib:12}: {format_time(avg_time):>8} avg, {avg_throughput:>6.1f} MB/s avg")

# Persist the raw measurements next to the notebook
df.to_csv("benchmark_results.csv", index=False)

# Speedup of every library relative to audio_samples, computed separately
# for each (operation, duration, channels) case
speedup_data = []
for operation in ['read', 'write']:
    for duration in DURATIONS:
        for channels in CHANNELS:
            case_mask = (
                (df['operation'] == operation) &
                (df['duration'] == duration) &
                (df['channels'] == channels)
            )
            subset = df[case_mask]

            # A comparison needs at least two measurements for this case
            if len(subset) < 2:
                continue

            baseline_times = subset.loc[subset['library'] == 'audio_samples', 'mean_time']
            if len(baseline_times) == 0:
                continue
            aus_baseline = baseline_times.iloc[0]

            for lib in LIBRARIES:
                lib_times = subset.loc[subset['library'] == lib, 'mean_time']
                if len(lib_times) == 0:
                    continue

                # Ratio of times: above 1.0 means `lib` took less time
                speedup = aus_baseline / lib_times.iloc[0]

                speedup_data.append({
                    'operation': operation,
                    'duration': duration,
                    'channels': channels,
                    'library': lib,
                    'speedup': speedup,
                    'winner': lib if speedup > 1.0 else 'audio_samples'
                })


speedup_df = pd.DataFrame(speedup_data)

print("\nSpeedup Analysis (relative to audio_samples baseline):")
print("Values > 1.0 mean the library is faster than audio_samples")
if len(speedup_df) == 0:
    print("No speedup data available")
else:
    grouped_speedup = speedup_df.groupby(['operation', 'library'])['speedup']
    summary = grouped_speedup.agg(['mean', 'min', 'max']).round(2)
    print(summary)

## Performance Visualizations

In [None]:
# Fixed library -> colour assignment shared by every plot: the n-th library
# named here takes the n-th palette entry (orange, sky blue, bluish green,
# yellow).
color_map = {
    name: OKABE_ITO_COLORS[position]
    for position, name in enumerate(('audio_samples', 'soundfile', 'scipy', 'torchaudio'))
}

# Marker shape encodes the channel count: circle = mono, square = stereo
marker_map = {1: 'o', 2: 's'}

# Figures are written into a sub-folder of the benchmark scratch directory
fig_dir = os.path.join(temp_dir, "figures")
os.makedirs(fig_dir, exist_ok=True)

def save_figure(fig, name, dpi=300):
    """Write *fig* to ``fig_dir`` as ``<name>.pdf`` and ``<name>.png``.

    Args:
        fig: Object providing ``savefig``, such as a matplotlib figure.
        name: File name without extension.
        dpi: Resolution passed to ``savefig`` for both formats.

    Both files use a tight bounding box and a white background. The saved
    paths are printed after both files have been written.
    """
    outputs = (
        ('pdf', os.path.join(fig_dir, f"{name}.pdf")),
        ('png', os.path.join(fig_dir, f"{name}.png")),
    )
    shared_options = dict(dpi=dpi, bbox_inches='tight',
                          facecolor='white', edgecolor='none')

    for file_format, path in outputs:
        fig.savefig(path, format=file_format, **shared_options)

    for _, path in outputs:
        print(f"Saved: {path}")

def _plot_time_vs_duration(df, operation):
    """Draw mean time against file duration for one operation.

    Shared implementation of the read and write plots, which differ only in
    the rows selected and in the words used for labels and the file name.

    Args:
        df: Results frame with ``operation``, ``library``, ``channels``,
            ``duration`` and ``mean_time`` columns.
        operation: ``'read'`` or ``'write'``.

    Returns:
        The matplotlib figure, already saved as ``<operation>_performance``
        and shown.
    """
    title_word = operation.capitalize()  # 'Read' / 'Write'
    fig, ax = plt.subplots(figsize=(12, 9))

    op_data = df[df['operation'] == operation]
    for channels in CHANNELS:
        channel_name = 'mono' if channels == 1 else 'stereo'
        for lib in LIBRARIES:
            subset = op_data[(op_data['channels'] == channels) & (op_data['library'] == lib)]
            if len(subset) == 0:
                continue
            # Colour identifies the library, marker shape the channel count
            ax.plot(subset['duration'], subset['mean_time'],
                    marker=marker_map[channels], label=f"{lib} ({channel_name})",
                    color=color_map[lib], linewidth=4, markersize=12)

    ax.set_xlabel('File Duration (seconds)', fontweight='bold')
    ax.set_ylabel(f'{title_word} Time (seconds)', fontweight='bold')
    ax.set_title(f'{title_word} Performance vs File Duration', fontweight='bold')
    # Log axes on both sides because durations and times span several decades
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.grid(True, alpha=0.3)

    # Legend sits below the axes; the bottom margin is widened to make room
    legend = ax.legend(bbox_to_anchor=(0.5, -0.12), loc='upper center', ncol=2)
    legend.get_frame().set_facecolor('white')
    legend.get_frame().set_alpha(0.9)

    plt.tight_layout()
    plt.subplots_adjust(bottom=0.15)
    save_figure(fig, f"{operation}_performance")
    plt.show()
    return fig


def plot_read_performance(df):
    """Plot read performance vs file duration and return the figure."""
    return _plot_time_vs_duration(df, 'read')


def plot_write_performance(df):
    """Plot write performance vs file duration and return the figure."""
    return _plot_time_vs_duration(df, 'write')

def plot_throughput_comparison(df):
    """Plot the mean throughput of every library as grouped bars.

    Args:
        df: Results frame with ``operation``, ``library`` and
            ``throughput_mb_s`` columns.

    Returns:
        The matplotlib figure, already saved as ``throughput_comparison``
        and shown.
    """
    fig, ax = plt.subplots(figsize=(12, 9))

    throughput_data = df.groupby(['operation', 'library'])['throughput_mb_s'].mean().reset_index()
    # groupby returns the libraries sorted alphabetically, which is not the
    # LIBRARIES order. A positional colour list would therefore give scipy
    # and soundfile each other's colour. Pinning the bar order and passing a
    # name -> colour mapping keeps the colours the same as in the line plots.
    sns.barplot(data=throughput_data, x='operation', y='throughput_mb_s', hue='library',
                hue_order=LIBRARIES, palette=color_map, ax=ax)

    ax.set_title('Average Throughput Comparison', fontweight='bold')
    ax.set_ylabel('Throughput (MB/s)', fontweight='bold')
    ax.set_xlabel('Operation', fontweight='bold')
    ax.tick_params(axis='both', which='major', labelsize=18, width=2)
    ax.grid(True, alpha=0.3)

    # Legend sits below the axes; the bottom margin is widened to make room
    legend = ax.legend(bbox_to_anchor=(0.5, -0.12), loc='upper center', ncol=2)
    legend.get_frame().set_facecolor('white')
    legend.get_frame().set_alpha(0.9)

    plt.tight_layout()
    plt.subplots_adjust(bottom=0.15)
    save_figure(fig, "throughput_comparison")
    plt.show()
    return fig

# Generate all individual plots; each call also saves its figure as PDF and
# PNG and displays it
print("Generating individual performance plots...")
read_fig = plot_read_performance(df)
write_fig = plot_write_performance(df)
throughput_fig = plot_throughput_comparison(df)

## Cleanup

In [None]:
# Uncomment the line below to clean up temporary files.
# Note: the saved figures live inside the same temporary directory, so
# cleaning up removes them as well.
# cleanup_temp_files()

print("Benchmark complete!")
print(f"Temporary files are preserved in: {temp_dir}")
print("Uncomment the cleanup_temp_files() call above to remove them.")