## Step 1: Import Libraries

In [1]:
import os
import numpy as np
import librosa
import librosa.display
import matplotlib.pyplot as plt
import soundfile as sf
from tqdm import tqdm

## Step 2: File Paths and Audio/Feature Settings

In [2]:
# Input (mono IRMAS clips) and output (spectrogram) roots. Each can be
# overridden through an environment variable so the notebook runs on another
# machine without editing code; the defaults are the original local paths.
IRMAS_MONO_ROOT = os.environ.get(
    "IRMAS_MONO_ROOT", r"E:\InstruNet-AI\data\irmas_mono"
)
OUTPUT_ROOT = os.environ.get(
    "IRMAS_MONO_OUTPUT_ROOT", r"E:\InstruNet-AI\data\post_preprocessing\irmas_mono"
)

# ===============================
# AUDIO & FEATURE SETTINGS
# ===============================

TARGET_SR = 16000             # Hz; every clip is resampled to this rate
FIXED_DURATION = 3.0          # seconds; clips are clipped / zero-padded to this length
N_MELS = 128                  # number of mel bands (rows of each spectrogram)

# STFT PARAMETERS
# A 2048-sample window at 16 kHz spans 128 ms; a 512-sample hop is 32 ms
# between frames. With librosa's default centred framing, a 3.0 s clip
# therefore yields 1 + 48000 // 512 = 94 frames.
N_FFT = 2048
HOP_LENGTH = 512
WIN_LENGTH = 2048
WINDOW = "hann"

## Step 3: Helper Functions

### (a) Load audio

In [3]:
def load_audio(path):
    """
    Load an audio file at its native sample rate, keeping all channels.

    Parameters
    ----------
    path : str
        Path to the audio file.

    Returns
    -------
    tuple
        ``(audio, sr)`` as returned by ``librosa.load(path, sr=None, mono=False)``,
        or ``(None, None)`` if loading raised any exception. Failures are
        printed so the caller can simply skip the file.
    """
    try:
        audio, sr = librosa.load(path, sr=None, mono=False)
        return audio, sr
    except Exception as e:
        # Include the exception type: some decoder errors carry an empty
        # message, which otherwise produces an uninformative "| " log line.
        print(f"[CORRUPTED] {os.path.basename(path)} | {type(e).__name__}: {e}")
        return None, None

### (b) Convert stereo → mono

In [4]:
def stereo_to_mono(audio):
    """
    Collapse a multi-channel signal to a single channel by averaging.

    A 1-D array is treated as already mono and returned untouched (same
    object). Otherwise channels are averaged along axis 0, which is the
    channel axis in the layout ``librosa.load(..., mono=False)`` produces.
    """
    is_multichannel = audio.ndim != 1
    if is_multichannel:
        return np.mean(audio, axis=0)
    return audio

### (c) Resample to 16 kHz

In [5]:
def resample_audio(audio, orig_sr, target_sr=TARGET_SR):
    """
    Resample ``audio`` from ``orig_sr`` to ``target_sr`` (default: TARGET_SR).

    If the two rates already match, the input array is returned as-is and
    librosa is not called.
    """
    if orig_sr == target_sr:
        return audio
    return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)

### (d) Peak-Normalize Audio

In [6]:
def peak_normalize(audio):
    """
    Scale a signal so that its largest absolute sample equals 1.0.

    Empty and all-zero signals are returned unchanged. There is no peak to
    scale by, and ``np.max`` on an empty array would raise ValueError and
    abort the whole batch.
    """
    if audio.size == 0:
        return audio
    peak = np.max(np.abs(audio))
    return audio / peak if peak > 0 else audio

### (e) Trim Silence (Amplitude Thresholding)

In [7]:
def trim_silence(audio, thresh=0.02):
    """
    Drop leading and trailing samples whose absolute amplitude is <= ``thresh``.

    The kept span runs from the first to the last sample above the threshold,
    inclusive; quiet stretches in the middle are kept. If no sample exceeds
    the threshold, the input is returned unchanged.

    Parameters
    ----------
    audio : np.ndarray
        1-D signal (mono).
    thresh : float
        Absolute amplitude threshold. The pipeline applies this after peak
        normalization, so it is relative to the clip's peak.
    """
    loud = np.flatnonzero(np.abs(audio) > thresh)
    if loud.size == 0:
        return audio
    # +1 because the slice end is exclusive; without it the last loud sample
    # is lost, and a clip with a single loud sample becomes empty.
    return audio[loud[0]: loud[-1] + 1]

### (f) Pad/Clip to Fixed Duration

In [8]:
def fix_duration(audio, sr=TARGET_SR, duration=FIXED_DURATION):
    """
    Force ``audio`` to exactly ``int(sr * duration)`` samples.

    Longer signals keep their first samples; shorter (or exactly-sized)
    signals are right-padded with zeros up to the target length.
    """
    n_target = int(sr * duration)
    shortfall = n_target - len(audio)
    if shortfall < 0:
        return audio[:n_target]
    return np.pad(audio, (0, shortfall), mode="constant")

## Step 4: Generate Normalized Log-Mel Spectrogram

In [9]:
def generate_log_mel(audio, sr=TARGET_SR):
    """
    Compute a per-clip standardized log-mel spectrogram.

    Steps:
    1. Power (|STFT|^2) mel spectrogram using the global settings
       N_FFT, HOP_LENGTH, WIN_LENGTH, WINDOW and N_MELS.
    2. Convert to dB relative to the clip's loudest bin
       (``librosa.power_to_db`` with ``ref=np.max``).
    3. Z-score over the whole spectrogram: subtract the mean, then divide by
       the std. The ``1e-8`` guards against division by zero for a constant
       input.

    Returns
    -------
    np.ndarray
        Array of shape (N_MELS, n_frames).
    """
    stft_kwargs = dict(
        n_fft=N_FFT,
        hop_length=HOP_LENGTH,
        win_length=WIN_LENGTH,
        window=WINDOW,
    )
    power_mel = librosa.feature.melspectrogram(
        y=audio, sr=sr, n_mels=N_MELS, power=2.0, **stft_kwargs
    )
    log_mel = librosa.power_to_db(power_mel, ref=np.max)

    mu = log_mel.mean()
    sigma = log_mel.std()
    return (log_mel - mu) / (sigma + 1e-8)

## Step 5: Enforce a Fixed Mel Shape

In [10]:
def fix_mel_frames(mel, target_frames=126):
    """
    Pad or crop a 2-D spectrogram to exactly ``target_frames`` columns (time axis = axis 1).

    Short inputs are right-padded with zeros; long inputs keep their first
    ``target_frames`` columns.

    NOTE(review): the Step 2 settings are 16 kHz, 3.0 s and hop 512. A clip
    then yields 1 + 48000 // 512 = 94 frames, so the default of 126 appends
    32 constant zero columns. The sanity-check std of ~0.864 ≈ sqrt(94/126)
    is consistent with this. 126 frames corresponds to ~4 s at these
    settings; confirm which length the downstream model expects.
    """
    missing = target_frames - mel.shape[1]
    if missing > 0:
        return np.pad(mel, ((0, 0), (0, missing)), mode="constant")
    return mel[:, :target_frames]

## Step 6: Save .npy and .png (Dual Output)

In [11]:
def save_outputs(mel_db, sr, out_npy_path, out_png_path=None):
    """
    Persist a spectrogram as ``.npy`` (model input) and, optionally, as ``.png``.

    Parameters
    ----------
    mel_db : np.ndarray
        Spectrogram to save.
    sr : int
        Sample rate, passed to ``librosa.display.specshow`` for the image.
    out_npy_path : str
        Destination of the ``np.save`` array.
    out_png_path : str or None
        If truthy, a 3x3-inch, axis-free ``magma`` rendering is also written
        here (for visual inspection only).
    """
    # NumPy array (MODEL INPUT)
    np.save(out_npy_path, mel_db)

    if not out_png_path:
        return

    # PNG (INSPECTION ONLY). An explicit figure/axes pair is closed in
    # `finally`: this runs once per clip (thousands of times), and a figure
    # left open after a plotting/saving error would accumulate in memory.
    fig, ax = plt.subplots(figsize=(3, 3))
    try:
        ax.set_axis_off()
        librosa.display.specshow(
            mel_db,
            sr=sr,
            hop_length=HOP_LENGTH,
            x_axis=None,
            y_axis=None,
            cmap="magma",
            ax=ax,
        )
        fig.savefig(out_png_path, bbox_inches="tight", pad_inches=0)
    finally:
        plt.close(fig)

## Step 7: Preprocess One Audio File (End-to-End)

In [12]:
def preprocess_single_file(audio_path, out_npy_path, out_png_path=None,
                           silence_thresh=0.02, target_frames=126):
    """
    Run the full preprocessing pipeline on one audio file and save the result.

    Pipeline: load → mono → resample to TARGET_SR → peak-normalize → trim
    leading/trailing silence → pad/clip to FIXED_DURATION → standardized
    log-mel → pad/crop to ``target_frames`` → save .npy (and .png if
    ``out_png_path`` is given).

    Parameters
    ----------
    audio_path : str
        Input audio file.
    out_npy_path : str
        Where the spectrogram array is saved.
    out_png_path : str or None
        Optional inspection image path.
    silence_thresh : float
        Amplitude threshold for ``trim_silence``. It is applied after peak
        normalization, so it is a fraction of the clip's peak.
    target_frames : int
        Number of time frames in the saved spectrogram.

    Returns
    -------
    bool
        False if the file could not be loaded (nothing is saved); True otherwise.

    NOTE(review): with TARGET_SR=16000, FIXED_DURATION=3.0 and HOP_LENGTH=512
    the log-mel has 1 + 48000 // 512 = 94 frames. The default
    ``target_frames=126`` therefore appends 32 constant padding columns; 126
    frames corresponds to ~4 s. The default is kept so existing outputs keep
    their (128, 126) shape. Pass
    ``1 + int(TARGET_SR * FIXED_DURATION) // HOP_LENGTH`` to avoid the
    padding if the downstream model accepts it.
    """
    audio, sr = load_audio(audio_path)
    if audio is None:
        return False  # unreadable file; load_audio has already reported it

    # Time-domain conditioning
    audio = stereo_to_mono(audio)
    audio = resample_audio(audio, orig_sr=sr)
    sr = TARGET_SR
    audio = peak_normalize(audio)
    audio = trim_silence(audio, thresh=silence_thresh)
    audio = fix_duration(audio, sr)

    # Time-frequency representation with a fixed frame count
    mel_db = generate_log_mel(audio, sr)
    mel_db = fix_mel_frames(mel_db, target_frames=target_frames)

    save_outputs(mel_db, sr, out_npy_path, out_png_path)
    return True

## Step 8: Create Output Directory Structure

In [13]:
# Dataset split folder names; the processing and sanity-check cells iterate over these.
splits = ["train", "val", "test"]

# Mirror the input class folders under OUTPUT_ROOT. Only sub-directories are
# treated as classes, so stray files in a split folder (e.g. desktop.ini,
# .DS_Store) don't turn into bogus class directories.
for split in splits:
    split_input_dir = os.path.join(IRMAS_MONO_ROOT, split)
    split_output_dir = os.path.join(OUTPUT_ROOT, split)

    for class_name in os.listdir(split_input_dir):
        if os.path.isdir(os.path.join(split_input_dir, class_name)):
            os.makedirs(os.path.join(split_output_dir, class_name), exist_ok=True)

## Step 9: Run Preprocessing for All Splits (Main Loop)

In [14]:
# Process every .wav in every split/class folder. Files that fail to load are
# skipped (load_audio prints them) and listed again in a summary at the end.
failed_files = []

for split in splits:
    print(f"\n=== Processing split: {split.upper()} ===")

    split_input_dir = os.path.join(IRMAS_MONO_ROOT, split)
    split_output_dir = os.path.join(OUTPUT_ROOT, split)

    for class_name in sorted(os.listdir(split_input_dir)):
        class_input_dir = os.path.join(split_input_dir, class_name)
        if not os.path.isdir(class_input_dir):
            continue  # ignore stray files next to the class folders
        class_output_dir = os.path.join(split_output_dir, class_name)
        # Create the folder here too, so this cell does not depend on the
        # directory-creation cell having been run first.
        os.makedirs(class_output_dir, exist_ok=True)

        wav_files = sorted(
            f for f in os.listdir(class_input_dir) if f.lower().endswith(".wav")
        )

        for wav in tqdm(wav_files, desc=f"{split}/{class_name}"):
            # splitext strips only the final extension; str.replace(".wav", "")
            # would also remove ".wav" occurring elsewhere in the name.
            base_name = os.path.splitext(wav)[0]

            success = preprocess_single_file(
                audio_path=os.path.join(class_input_dir, wav),
                out_npy_path=os.path.join(class_output_dir, base_name + ".npy"),
                out_png_path=os.path.join(class_output_dir, base_name + ".png"),
            )
            if not success:
                failed_files.append(os.path.join(split, class_name, wav))

print(f"\nDone. {len(failed_files)} file(s) could not be loaded and were skipped.")
for failed_path in failed_files:
    print("  -", failed_path)


=== Processing split: TRAIN ===


train/cel: 100%|█████████████████████████████████████████████████████████████████████| 272/272 [00:30<00:00,  8.89it/s]
train/cla: 100%|█████████████████████████████████████████████████████████████████████| 353/353 [00:39<00:00,  8.96it/s]
train/flu: 100%|█████████████████████████████████████████████████████████████████████| 316/316 [00:36<00:00,  8.72it/s]
train/gac: 100%|█████████████████████████████████████████████████████████████████████| 446/446 [00:55<00:00,  8.05it/s]
train/gel: 100%|█████████████████████████████████████████████████████████████████████| 532/532 [01:16<00:00,  6.97it/s]
train/org: 100%|█████████████████████████████████████████████████████████████████████| 477/477 [01:14<00:00,  6.38it/s]
train/pia: 100%|█████████████████████████████████████████████████████████████████████| 505/505 [01:29<00:00,  5.66it/s]
train/sax: 100%|█████████████████████████████████████████████████████████████████████| 438/438 [01:28<00:00,  4.97it/s]
train/tru: 100%|████████████████████████

[CORRUPTED] 075__[voi][dru][pop_roc]2329__2.wav | 


train/voi: 100%|█████████████████████████████████████████████████████████████████████| 544/544 [02:22<00:00,  3.81it/s]



=== Processing split: VAL ===


val/cel: 100%|█████████████████████████████████████████████████████████████████████████| 58/58 [00:16<00:00,  3.41it/s]
val/cla: 100%|█████████████████████████████████████████████████████████████████████████| 76/76 [00:22<00:00,  3.41it/s]
val/flu: 100%|█████████████████████████████████████████████████████████████████████████| 67/67 [00:19<00:00,  3.36it/s]
val/gac: 100%|█████████████████████████████████████████████████████████████████████████| 96/96 [00:31<00:00,  3.04it/s]
val/gel: 100%|███████████████████████████████████████████████████████████████████████| 114/114 [00:34<00:00,  3.32it/s]
val/org: 100%|███████████████████████████████████████████████████████████████████████| 102/102 [00:33<00:00,  3.06it/s]
val/pia: 100%|███████████████████████████████████████████████████████████████████████| 108/108 [00:34<00:00,  3.16it/s]
val/sax: 100%|█████████████████████████████████████████████████████████████████████████| 94/94 [00:31<00:00,  3.00it/s]
val/tru: 100%|██████████████████████████


=== Processing split: TEST ===


test/cel: 100%|████████████████████████████████████████████████████████████████████████| 58/58 [00:20<00:00,  2.83it/s]
test/cla: 100%|████████████████████████████████████████████████████████████████████████| 76/76 [00:27<00:00,  2.75it/s]
test/flu: 100%|████████████████████████████████████████████████████████████████████████| 68/68 [00:25<00:00,  2.72it/s]
test/gac: 100%|████████████████████████████████████████████████████████████████████████| 95/95 [00:34<00:00,  2.77it/s]
test/gel: 100%|██████████████████████████████████████████████████████████████████████| 114/114 [00:42<00:00,  2.66it/s]
test/org: 100%|██████████████████████████████████████████████████████████████████████| 103/103 [00:41<00:00,  2.48it/s]
test/pia: 100%|██████████████████████████████████████████████████████████████████████| 108/108 [00:50<00:00,  2.13it/s]
test/sax: 100%|████████████████████████████████████████████████████████████████████████| 94/94 [00:38<00:00,  2.43it/s]
test/tru: 100%|█████████████████████████

## Step 10: Final Sanity Checks

In [15]:
# Collect every saved spectrogram as (split, class, path), in directory-listing order.
npy_entries = []
for split in splits:
    split_dir = os.path.join(OUTPUT_ROOT, split)
    for cls in os.listdir(split_dir):
        cls_dir = os.path.join(split_dir, cls)
        if not os.path.isdir(cls_dir):
            continue
        npy_entries.extend(
            (split, cls, os.path.join(cls_dir, f))
            for f in os.listdir(cls_dir)
            if f.endswith(".npy")
        )

assert npy_entries, f"no .npy files found under {OUTPUT_ROOT}"

# Spot-check the first file found (deterministic, not a random pick).
sample_split, sample_cls, sample_path = npy_entries[0]
mel = np.load(sample_path)
print(f"{sample_split}/{sample_cls} sample shape:", mel.shape)

# Every saved spectrogram must share one shape. mmap_mode="r" parses the
# header without reading the array data, which keeps this pass cheap.
all_shapes = {np.load(p, mmap_mode="r").shape for _, _, p in npy_entries}
print(f"{len(npy_entries)} files, shapes found:", all_shapes)
assert len(all_shapes) == 1, f"inconsistent spectrogram shapes: {all_shapes}"

train/cel sample shape: (128, 126)


In [16]:
# Value range and moments (min, max, mean, std) of the spot-checked spectrogram.
# Clips are z-scored before frame padding, so the mean stays ~0 while the std
# drops below 1 when zero columns are appended. The observed 0.8637 is
# consistent with sqrt(94/126) for 94 real frames padded to 126.
sample_stats = (mel.min(), mel.max(), mel.mean(), mel.std())
print(*sample_stats)

-2.1632617 3.3033133 1.9678994e-07 0.8637313
