In [4]:
import polars as pl
import numpy as np
import random
from typing import Dict, List, Optional, Tuple
from pathlib import Path
import soundfile as sf
import os

In [5]:
def generate_synthetic_audio_data(
    df: pl.DataFrame,
    target_count_per_label: int = 150,
    output_path: str = "balanced_audio_data.parquet",
    save_samples: int = 0,  # Number of samples to save (0 = none)
    sample_rate: int = 44100,  # Default sample rate
    output_dir: str = "synthetic_samples",
    batch_size: int = 50  # Process in batches to reduce memory usage
) -> pl.DataFrame:
    """
    Balance the class distribution of an audio dataset with synthetic rows.

    For every label with fewer than ``target_count_per_label`` rows, synthetic
    rows are generated (via ``generate_synthetic_sample``) in batches and
    appended to a copy of ``df``. The combined frame is written to
    ``output_path``; intermediate checkpoints go to ``f"{output_path}.temp"``.

    Args:
        df: Source data. Must contain "Label" and "Index" columns; the other
            columns are carried over to the synthetic rows where available.
        target_count_per_label: Minimum number of rows wanted per label.
        output_path: Parquet file that receives the final combined dataset.
        save_samples: How many synthetic samples to also write as WAV files
            (0 disables WAV export).
        sample_rate: Sample rate passed to ``sf.write`` for exported WAVs.
        output_dir: Directory for exported WAV files.
        batch_size: Number of synthetic rows built before they are appended.

    Returns:
        The original rows followed by all successfully generated synthetic rows.

    Raises:
        ValueError: If ``batch_size`` is not positive.

    Note:
        Generation is best-effort: a sample or batch that fails is reported on
        stdout and skipped, so a label may end up below the target.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # Create output directory if it doesn't exist
    if save_samples > 0 and not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
        print(f"Created directory {output_dir} for synthetic audio samples")

    # pl.len() replaces the deprecated pl.count(); the alias keeps the
    # historical "count" column name used below and in the printed tables.
    label_counts = df.group_by("Label").agg(pl.len().alias("count")).sort("count")
    print(f"Original label distribution:\n{label_counts}")

    # Number of synthetic rows still required for each under-represented label
    synthetic_needs = {
        label: target_count_per_label - count
        for label, count in label_counts.iter_rows()
        if count < target_count_per_label
    }

    # Work on a copy so the caller's frame is never touched
    combined_df = df.clone()

    # Single running counter shared by ALL labels, so synthetic Index values
    # (and the filenames derived from them) never repeat across labels.
    max_index = df["Index"].max()
    next_index = int(max_index) + 1 if max_index is not None else 0
    samples_saved = 0

    for label, examples_needed in synthetic_needs.items():
        print(f"Generating {examples_needed} synthetic samples for label '{label}'")

        # Source rows that synthetic samples for this label are mixed from
        label_samples = df.filter(pl.col("Label") == label)
        if len(label_samples) == 0:
            # e.g. a null label: the equality filter matches nothing
            print(f"No source samples found for label '{label}', skipping")
            continue

        for batch_start in range(0, examples_needed, batch_size):
            batch_end = min(batch_start + batch_size, examples_needed)
            print(f"Processing batch {batch_start+1}-{batch_end} for label '{label}'")

            batch_rows = []
            for _ in range(batch_end - batch_start):
                # An index is consumed per attempt; failures only leave gaps
                new_index = next_index
                next_index += 1
                try:
                    synthetic_sample = generate_synthetic_sample(
                        label_samples,
                        label,
                        new_index
                    )
                except Exception as e:
                    print(f"Error generating synthetic sample for label '{label}': {e}")
                    continue
                batch_rows.append(synthetic_sample)

                # Save synthetic samples as WAV files if requested
                if save_samples > 0 and samples_saved < save_samples:
                    try:
                        audio_data = np.array(synthetic_sample["Audio"], dtype=np.float32)
                        filename = synthetic_sample["Filename"]
                        wav_path = os.path.join(output_dir, f"sample_{samples_saved+1}_{filename}")
                        sf.write(wav_path, audio_data, sample_rate)
                        print(f"Saved synthetic audio sample {samples_saved+1} to {wav_path}")
                        samples_saved += 1
                    except Exception as e:
                        # WAV export is optional; the row itself is still kept
                        print(f"Error saving audio sample: {e}")

            if not batch_rows:
                continue

            batch_df = None
            try:
                batch_df = pl.DataFrame(batch_rows)

                # Columns the generator does not produce are filled with nulls
                missing_columns = [col for col in df.columns if col not in batch_df.columns]
                if missing_columns:
                    batch_df = batch_df.with_columns(
                        [pl.lit(None).alias(col) for col in missing_columns]
                    )

                # Same columns, order and dtypes as the original frame:
                # a vertical concat rejects frames whose schemas differ.
                batch_df = batch_df.select(
                    [pl.col(col).cast(df.schema[col]) for col in df.columns]
                )

                combined_df = pl.concat([combined_df, batch_df])

                # Checkpoint every third batch and at the end of each label
                if batch_end % (batch_size * 3) == 0 or batch_end == examples_needed:
                    temp_output = f"{output_path}.temp"
                    combined_df.write_parquet(temp_output)
                    print(f"Saved intermediate balanced dataset to {temp_output}")
            except Exception as e:
                print(f"Error processing batch: {e}")
                if batch_df is not None:
                    print(f"Batch width: {len(batch_df.columns)}")

    # Save final combined dataset
    combined_df.write_parquet(output_path)
    print(f"Saved balanced dataset to {output_path}")

    # Show final distribution
    new_label_counts = combined_df.group_by("Label").agg(pl.len().alias("count")).sort("count")
    print(f"New label distribution:\n{new_label_counts}")

    return combined_df


def generate_synthetic_sample(
    source_samples: pl.DataFrame,
    label: str,
    new_index: int
) -> Dict:
    """
    Generate a single synthetic audio sample by mixing source samples.

    Two or three distinct rows are drawn at random from ``source_samples``
    (just one when only a single row exists). Their audio is truncated to the
    shortest track and blended with an independent Dirichlet weight vector per
    audio frame, then Gaussian noise (sigma 0.02) is added. The mix is computed
    with vectorized numpy operations instead of a per-frame Python loop, so the
    order of random draws differs from a frame-by-frame loop while the
    distribution of the result is the same.

    Args:
        source_samples: Rows of one label. The "Audio", "Filename", "ID",
            "Duration" and "Spectrogram" columns are used when present.
        label: Label assigned to the synthetic sample.
        new_index: Value stored in "Index" and embedded in the new filename.

    Returns:
        A dict with the keys "Filename", "Audio", "ID", "Label", "Duration",
        "Index" and "Spectrogram". "Audio" is a list of floats; it is empty
        when none of the selected rows carries audio.

    Raises:
        ValueError: If ``source_samples`` has no rows.
    """
    num_samples = len(source_samples)
    if num_samples == 0:
        raise ValueError(f"No source samples available for label '{label}'")

    # Mix 2-3 sources when possible. A label with a single example falls back
    # to that one row (randint(2, 1) would raise), i.e. source plus noise.
    num_samples_to_mix = random.randint(min(2, num_samples), min(3, num_samples))
    sample_indices = random.sample(range(num_samples), num_samples_to_mix)

    selected_samples = [source_samples.row(i, named=True) for i in sample_indices]
    first_sample = selected_samples[0]

    # The participant id is the filename prefix before the first underscore
    filename = first_sample.get("Filename") or ""
    participant_id = filename.split("_")[0] if "_" in filename else "P00"
    new_filename = f"{participant_id}_{label}_synthetic_{new_index}.wav"

    # Only rows that actually carry audio take part in the mix; a null column
    # value arrives as None (not as a missing key), so check for it explicitly.
    tracks = []
    for sample in selected_samples:
        raw_audio = sample.get("Audio")
        if raw_audio is not None and len(raw_audio) > 0:
            tracks.append(raw_audio)

    if tracks:
        # Truncate to the shortest track so every frame has all contributors
        min_audio_length = min(len(track) for track in tracks)
        stacked = np.stack(
            [np.asarray(track[:min_audio_length], dtype=np.float64) for track in tracks],
            axis=1,
        )  # shape: (frames, sources)
        # One weight vector per frame; each row of `weights` sums to 1
        weights = np.random.dirichlet(np.ones(len(tracks)), size=min_audio_length)
        noise = np.random.normal(0, 0.02, size=min_audio_length)
        mixed = (stacked * weights).sum(axis=1) + noise
        # float32 round-trip, then plain Python floats for polars
        synthetic_audio = mixed.astype(np.float32).tolist()
    else:
        synthetic_audio = []

    # Get metadata from first sample
    id_value = first_sample.get("ID", participant_id)

    # Average duration (1.0 stands in for missing/null values), jittered +-10%
    duration_values = []
    for sample in selected_samples:
        duration = sample.get("Duration")
        duration_values.append(float(duration) if duration is not None else 1.0)
    avg_duration = sum(duration_values) / len(duration_values)
    new_duration = avg_duration * random.uniform(0.9, 1.1)

    new_spectrogram = process_spectrogram_efficiently(selected_samples)

    return {
        "Filename": new_filename,
        "Audio": synthetic_audio,
        "ID": id_value,
        "Label": label,
        "Duration": float(new_duration),
        "Index": int(new_index),
        "Spectrogram": new_spectrogram
    }


def process_spectrogram_efficiently(selected_samples: List[Dict]) -> List[List[float]]:
    """
    Build a synthetic spectrogram by blending the spectrograms of the samples.

    The result has as many rows as the first sample's spectrogram. Each cell is
    a Dirichlet-weighted average of the matching cells of every sample that
    provides that row, plus Gaussian noise (sigma 0.05). A row is truncated to
    the shortest contributing row. When no sample provides a row, the previous
    output row is copied with noise (sigma 0.1), or a pure-noise row is created
    if there is no previous row.

    Args:
        selected_samples: Row dictionaries. Each may hold a "Spectrogram" entry
            (a list of rows of numbers); a missing or null entry is treated as
            an empty spectrogram.

    Returns:
        The blended spectrogram as a list of rows of Python floats, or ``[]``
        when there are no samples or the first sample has no spectrogram.
    """
    if not selected_samples:
        return []

    # A null column value arrives as None rather than as a missing key, so
    # normalise both cases to an empty list once, for every sample.
    spectrograms = [sample.get("Spectrogram") or [] for sample in selected_samples]

    # The first sample dictates the number of output rows
    spec_height = len(spectrograms[0])
    if spec_height == 0:
        return []

    # Width used only for a pure-noise row, i.e. when no sample supplies the
    # leading row(s); 10 is the fallback when no usable first row exists.
    first_row_len = next(
        (len(spec[0]) for spec in spectrograms if spec and spec[0]),
        10,
    )

    new_spectrogram = []
    for row_idx in range(spec_height):
        # Non-empty rows at this position across all samples
        available_rows = [
            spec[row_idx]
            for spec in spectrograms
            if row_idx < len(spec) and spec[row_idx]
        ]

        if not available_rows:
            if new_spectrogram:
                # Copy last row with some noise
                new_row = [float(v + np.random.normal(0, 0.1)) for v in new_spectrogram[-1]]
            else:
                # Create random row
                new_row = [float(np.random.normal(0, 0.1)) for _ in range(first_row_len)]
        else:
            # Shortest contributing row, so every source has each column
            row_width = min(len(row) for row in available_rows)

            new_row = []
            for col_idx in range(row_width):
                try:
                    values = [float(row[col_idx]) for row in available_rows]
                except (TypeError, ValueError):
                    # Non-numeric cell (e.g. None): substitute noise
                    new_row.append(float(np.random.normal(0, 0.1)))
                    continue
                weights = np.random.dirichlet(np.ones(len(values)))
                avg_value = sum(v * w for v, w in zip(values, weights))
                new_row.append(float(avg_value + np.random.normal(0, 0.05)))

        new_spectrogram.append(new_row)

    return new_spectrogram


def load_and_balance_audio_data(
    input_path: str,
    target_count: int = 150,
    output_path: str = "balanced_audio_data.parquet",
    save_samples: int = 0,
    sample_rate: int = 44100,
    output_dir: str = "synthetic_samples",
    batch_size: int = 50
) -> pl.DataFrame:
    """
    Read a parquet dataset and hand it to ``generate_synthetic_audio_data``.

    Args:
        input_path: Parquet file to read.
        target_count: Forwarded as the per-label target row count.
        output_path: Forwarded as the destination of the balanced dataset.
        save_samples: Forwarded; number of synthetic samples to export as WAV.
        sample_rate: Forwarded; sample rate for exported WAV files.
        output_dir: Forwarded; directory for exported WAV files.
        batch_size: Forwarded; number of synthetic rows per batch.

    Returns:
        Whatever ``generate_synthetic_audio_data`` returns for the loaded frame.

    Raises:
        Exception: Any error from loading or balancing is printed and re-raised.
    """
    # Options passed through untouched to the generator
    generator_options = {
        "save_samples": save_samples,
        "sample_rate": sample_rate,
        "output_dir": output_dir,
        "batch_size": batch_size,
    }

    try:
        print(f"Loading data from {input_path}")
        source_df = pl.read_parquet(input_path)

        # Quick overview of what was loaded, for debugging
        print(f"Available columns: {source_df.columns}")
        print(f"Number of rows: {len(source_df)}")

        balanced = generate_synthetic_audio_data(
            source_df,
            target_count,
            output_path,
            **generator_options
        )
    except Exception as err:
        print(f"Error in load_and_balance_audio_data: {err}")
        raise

    return balanced

In [9]:
# Parquet file holding the pre-processed dataset to balance
input_file = "../checkpoint2/processed_data.parquet"

# Labels with fewer rows than this receive synthetic samples (adjust as needed)
target_samples_per_label = 150

# Run the pipeline: load -> generate synthetic rows -> write balanced parquet
balanced_df = load_and_balance_audio_data(
    input_path=input_file,
    target_count=target_samples_per_label,
    output_path="balanced_audio_data.parquet",
    save_samples=0,  # no WAV export
)

# Compare sizes; the input file is read again here only to report its shape
print(f"Original data shape: {pl.read_parquet(input_file).shape}")
print(f"Balanced data shape: {balanced_df.shape}")

Loading data from /Users/shivam/Documents/ComSci/Classes/CMSC320/cmsc320-final/checkpoint2/processed_data.parquet
Available columns: ['Filename', 'Audio', 'ID', 'Label', 'Duration', 'Index', 'Spectrogram']
Number of rows: 7077
Original label distribution:
shape: (22, 2)
┌──────────────┬───────┐
│ Label        ┆ count │
│ ---          ┆ ---   │
│ str          ┆ u32   │
╞══════════════╪═══════╡
│ greeting     ┆ 3     │
│ hunger       ┆ 4     │
│ tablet       ┆ 7     │
│ glee         ┆ 8     │
│ laugh        ┆ 8     │
│ …            ┆ …     │
│ social       ┆ 634   │
│ dysregulated ┆ 704   │
│ delighted    ┆ 1272  │
│ frustrated   ┆ 1536  │
│ selftalk     ┆ 1885  │
└──────────────┴───────┘
Generating 2 synthetic samples for label 'greeting'
Processing batch 1-2 for label 'greeting'


  label_counts = df.group_by("Label").agg(pl.count()).sort("count")


Saved synthetic audio sample 1 to synthetic_samples/sample_1_P11_greeting_synthetic_782.wav
Saved synthetic audio sample 2 to synthetic_samples/sample_2_P11_greeting_synthetic_783.wav
