## Step 1: Imports

In [1]:
import os
import numpy as np
import librosa
import librosa.display
import tensorflow as tf
from tqdm import tqdm
from sklearn.metrics import classification_report, f1_score

## Step 2: Paths & Parameters

In [2]:
# Data and model locations. Set the environment variables to run on another
# machine; the defaults are the original author's local paths.
TEST_AUDIO_ROOT = os.environ.get(
    "INSTRUNET_TEST_AUDIO_ROOT", r"E:\InstruNet-AI\data\polyphonic_test_data"
)
MODEL_PATH = os.environ.get(
    "INSTRUNET_MODEL_PATH", r"E:\InstruNet-AI\saved_models\best_l2_regularized_model.h5"
)

# Sliding-window framing: each analysis window is WINDOW_SEC long and
# consecutive windows start HOP_SEC apart (50% overlap).
TARGET_SR = 16000
WINDOW_SEC = 3.0
HOP_SEC = 1.5

# Log-mel feature size fed to the model: N_MELS rows x TARGET_FRAMES columns.
# NOTE(review): a 3 s window at 16 kHz with hop_length=512 (generate_log_mel) and
# librosa's default center=True yields 1 + 48000 // 512 = 94 frames, so
# fix_mel_frames zero-pads 32 frames; 126 frames corresponds to a 4 s window.
# Confirm this matches the preprocessing used at training time.
N_MELS = 128
TARGET_FRAMES = 126

NUM_CLASSES = 11
# NOTE(review): not referenced by any cell shown here; evaluation uses
# PER_CLASS_THRESHOLDS instead.
GLOBAL_THRESHOLD = 0.25
# Added to the standard deviation when standardising spectrograms (avoids /0).
EPS = 1e-8

In [3]:
# List the test clips. os.listdir order is arbitrary, so sort to keep the
# file order (and therefore the row order of every result matrix) reproducible.
test_files = sorted(f for f in os.listdir(TEST_AUDIO_ROOT) if f.endswith(".wav"))
print(f"Found {len(test_files)} test files.")

Found 1573 test files.


## Step 3: Class Definitions

In [4]:
# Instrument codes; column i of every label / probability / prediction matrix
# is treated as class_names[i] (classification_report and the per-class
# threshold arrays rely on this). NOTE(review): assumed to match the model's
# output order from training — confirm.
class_names = [
    "cel", "cla", "flu", "gac", "gel",
    "org", "pia", "sax", "tru", "vio", "voi"
]

# Fail fast if the list and NUM_CLASSES drift apart: load_multilabel_gt sizes
# its vectors with NUM_CLASSES but indexes them through class_to_id.
assert len(class_names) == NUM_CLASSES, (
    f"class_names has {len(class_names)} entries but NUM_CLASSES is {NUM_CLASSES}"
)

# Reverse lookup: instrument code -> column index.
class_to_id = {c: i for i, c in enumerate(class_names)}

## Step 4: Load Model

In [5]:
# Restore the trained Keras network saved at MODEL_PATH, then print its layers.
model = tf.keras.models.load_model(filepath=MODEL_PATH)
model.summary()



## Step 5: Audio Preprocessing Utilities

In [6]:
def stereo_to_mono(audio):
    """Mix a multichannel signal down to one channel by averaging.

    1-D input is treated as already mono and returned as-is (same object).
    Multichannel input is expected channel-first, i.e. (channels, samples),
    which is the layout librosa.load(..., mono=False) produces.
    """
    return audio if audio.ndim == 1 else np.mean(audio, axis=0)

def peak_normalize(audio):
    """Scale a signal so its largest absolute sample becomes 1.

    If the peak is not positive (e.g. an all-zero signal) the input is
    returned unchanged rather than divided by zero.
    """
    peak = np.abs(audio).max()
    if peak > 0:
        return audio / peak
    return audio

def trim_silence(audio, thresh=0.02):
    """Crop leading and trailing quiet samples from a 1-D signal.

    Keeps the span from the first to the last sample whose absolute value
    exceeds ``thresh``, with both endpoints included. Quiet samples between
    those endpoints are kept. If no sample exceeds ``thresh``, the input is
    returned unchanged.

    Args:
        audio: 1-D signal. In extract_features it is applied after
            peak_normalize, so ``thresh`` is relative to a unit peak there.
        thresh: absolute amplitude threshold.

    Returns:
        The trimmed signal (a slice of ``audio``).
    """
    loud = np.where(np.abs(audio) > thresh)[0]
    if len(loud) == 0:
        return audio
    # Slice stops are exclusive: +1 keeps the last loud sample. Without it a
    # signal with a single loud sample would be trimmed to nothing.
    return audio[loud[0]: loud[-1] + 1]

def fix_duration(audio, sr=TARGET_SR, duration=WINDOW_SEC):
    """Force a 1-D signal to exactly int(sr * duration) samples.

    Longer signals are truncated at the end; shorter ones are right-padded
    with zeros.
    """
    target_len = int(sr * duration)
    shortfall = target_len - len(audio)
    if shortfall < 0:
        return audio[:target_len]
    return np.pad(audio, (0, shortfall), mode="constant")

## Step 6: Log-Mel Spectrogram Extraction

In [7]:
def generate_log_mel(audio, sr=TARGET_SR):
    """Compute a per-clip standardised log-mel spectrogram.

    A power (power=2.0) mel spectrogram with N_MELS bands is converted to dB
    relative to its own maximum. It is then shifted and scaled to zero mean
    and unit variance over the whole spectrogram (EPS guards a zero std).

    Returns:
        Array with N_MELS rows and one column per STFT frame.
    """
    power_mel = librosa.feature.melspectrogram(
        y=audio, sr=sr,
        n_fft=2048, hop_length=512, win_length=2048, window="hann",
        n_mels=N_MELS, power=2.0,
    )
    log_mel = librosa.power_to_db(power_mel, ref=np.max)
    centre, spread = log_mel.mean(), log_mel.std()
    return (log_mel - centre) / (spread + EPS)

In [8]:
def fix_mel_frames(mel, target_frames=TARGET_FRAMES):
    """Pad (with zeros) or crop a spectrogram along its time axis to target_frames."""
    missing = target_frames - mel.shape[1]
    if missing > 0:
        return np.pad(mel, ((0, 0), (0, missing)), mode="constant")
    return mel[:, :target_frames]

In [9]:
def extract_features(y, sr=TARGET_SR):
    """Turn one audio window into a fixed-size standardised log-mel spectrogram.

    Pipeline: mono mixdown -> peak normalisation -> silence trimming ->
    truncate/zero-pad to WINDOW_SEC -> log-mel -> crop/pad to TARGET_FRAMES.
    Because trimming happens before padding, a window with quiet onset has its
    content shifted to the start and zero-padded at the end.
    NOTE(review): presumably mirrors the training-time preprocessing — confirm.
    """
    signal = y
    for step in (stereo_to_mono, peak_normalize, trim_silence):
        signal = step(signal)
    signal = fix_duration(signal, sr)
    return fix_mel_frames(generate_log_mel(signal, sr))

## Step 7: Sliding Window Segmentation

In [10]:
def sliding_windows(y, sr, window_sec=None, hop_sec=None, include_tail=False):
    """Yield fixed-length, possibly overlapping windows of a 1-D signal.

    Windows are ``window_sec`` long and start every ``hop_sec`` seconds. Both
    default to the module-level WINDOW_SEC / HOP_SEC and are read at call time.

    Edge cases:
      * A non-empty signal shorter than one window is yielded once, unpadded.
        This lets short clips still get a prediction (extract_features pads it
        via fix_duration). Previously such clips produced no windows, and the
        callers' mean/stack then failed.
      * An empty signal yields nothing.

    Args:
        y: 1-D signal.
        sr: sample rate of ``y`` in Hz.
        window_sec: window length in seconds (default WINDOW_SEC).
        hop_sec: hop between window starts in seconds (default HOP_SEC).
        include_tail: if True and the regular hop grid leaves trailing samples
            uncovered, also yield one final window aligned to the end of ``y``.
            Defaults to False, matching the original behaviour.

    Raises:
        ValueError: (on iteration) if the window or hop is shorter than one sample.
    """
    if window_sec is None:
        window_sec = WINDOW_SEC
    if hop_sec is None:
        hop_sec = HOP_SEC

    win_len = int(sr * window_sec)
    hop_len = int(sr * hop_sec)
    if win_len <= 0 or hop_len <= 0:
        raise ValueError(f"window ({win_len}) and hop ({hop_len}) must be at least one sample")

    n_samples = len(y)
    if n_samples == 0:
        return
    if n_samples < win_len:
        yield y
        return

    last_start = 0
    for start in range(0, n_samples - win_len + 1, hop_len):
        last_start = start
        yield y[start:start + win_len]

    if include_tail and last_start + win_len < n_samples:
        yield y[n_samples - win_len:]

## Step 8: Ground Truth Loader

In [11]:
def load_multilabel_gt(txt_path, class_map=None, num_classes=None):
    """Read a label file into a multi-hot ground-truth vector.

    The file lists one instrument code per line. Surrounding whitespace (such
    as the trailing tabs seen in this dataset's files) is ignored, and matching
    is case-insensitive. Blank lines are skipped. Codes not in ``class_map``
    trigger a warning instead of being dropped silently, which would shrink
    the class support unnoticed.

    Args:
        txt_path: path to the label file.
        class_map: mapping code -> column index; defaults to ``class_to_id``.
        num_classes: length of the output vector; defaults to ``NUM_CLASSES``.

    Returns:
        int32 array of length ``num_classes``; 1 for every listed known class.
    """
    import warnings

    if class_map is None:
        class_map = class_to_id
    if num_classes is None:
        num_classes = NUM_CLASSES

    labels = np.zeros(num_classes, dtype=np.int32)
    unknown = []
    # utf-8-sig decodes plain UTF-8 identically and also drops a leading BOM,
    # which would otherwise stick to the first code and make it unmatchable.
    with open(txt_path, encoding="utf-8-sig") as f:
        for line in f:
            inst = line.strip().lower()   # removes \t, \n, spaces
            if not inst:
                continue
            if inst in class_map:
                labels[class_map[inst]] = 1
            else:
                unknown.append(inst)

    if unknown:
        warnings.warn(f"{txt_path}: ignoring unknown label(s) {unknown}", stacklevel=2)
    return labels

## Step 9: Aggregated Prediction (No threshold)

In [12]:
def predict_with_aggregation(audio_path):
    """Predict track-level class probabilities by averaging window predictions.

    Processing steps:
      1. Load the file at TARGET_SR and mix it down to mono.
      2. Cut it into overlapping windows (sliding_windows).
      3. Convert each window to log-mel features (extract_features).
      4. Score all windows with ``model`` in one batched call.
      5. Average the per-window probability vectors.

    Args:
        audio_path: path to an audio file readable by librosa.

    Returns:
        1-D array with one averaged probability per model output (class).

    Raises:
        ValueError: if no analysis window could be extracted, e.g. the audio
            is empty or shorter than one window.
    """
    y, _ = librosa.load(audio_path, sr=TARGET_SR, mono=False)
    y = stereo_to_mono(y)

    features = [extract_features(window) for window in sliding_windows(y, sr=TARGET_SR)]
    if not features:
        # Averaging an empty list would silently give NaN and break np.stack later.
        raise ValueError(
            f"No analysis windows extracted from {audio_path!r}; "
            "the audio is empty or shorter than one window."
        )

    # (n_windows, N_MELS, TARGET_FRAMES) -> add the trailing channel axis.
    batch = np.stack(features)[..., np.newaxis]
    # One predict call for all windows instead of one per window.
    window_probs = model.predict(batch, verbose=0)
    return window_probs.mean(axis=0)

## Step 10: Collect Aggregated Probs (for threshold learning)

In [13]:
# Re-list the test clips (sorted: os.listdir order is arbitrary across machines).
test_files = sorted(f for f in os.listdir(TEST_AUDIO_ROOT) if f.endswith(".wav"))
print("Number of test files:", len(test_files))

true_rows, prob_rows = [], []

# NOTE: runs model inference on every clip (about 1 h 08 min in the recorded run).
for wav in tqdm(test_files, desc="Aggregated inference (no threshold)"):
    audio_path = os.path.join(TEST_AUDIO_ROOT, wav)
    # Swap only the extension; str.replace(".wav", ".txt") would also rewrite a
    # ".wav" appearing anywhere else in the path.
    txt_path = os.path.splitext(audio_path)[0] + ".txt"

    true_rows.append(load_multilabel_gt(txt_path))
    prob_rows.append(predict_with_aggregation(audio_path))

# One row per test file: multi-hot ground truth and averaged probabilities.
y_true_ag = np.stack(true_rows)
y_pred_ag_probs = np.stack(prob_rows)

Number of test files: 1573


Aggregated inference (no threshold): 100%|███████████████████████████████████████| 1573/1573 [1:08:33<00:00,  2.61s/it]


In [14]:
# Quick range check of the averaged probabilities before threshold tuning.
print("Aggregated probability stats:")
for stat_label, reducer in (("Min:", np.min), ("Max:", np.max), ("Mean:", np.mean)):
    print(stat_label, reducer(y_pred_ag_probs))

Aggregated probability stats:
Min: 7.11291e-05
Max: 0.9993469
Mean: 0.08538698


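## Step 11: Learn Per-Class Thresholds

**Caveat:** these thresholds are tuned on the same test files that Steps 13 and 15 evaluate, so the reported F1 scores are optimistic. For an unbiased estimate, tune the thresholds on a held-out validation split and apply them unchanged to the test set.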

In [15]:
def find_best_threshold(y_true, y_probs, thresholds=None):
    """Grid-search the decision threshold that maximises binary F1 for one class.

    Args:
        y_true: 0/1 ground-truth labels for a single class.
        y_probs: predicted scores for that class, aligned with ``y_true``.
        thresholds: candidate thresholds, tried in the given order. Defaults
            to 0.01, 0.02, ..., 0.49 (np.arange(0.01, 0.5, 0.01)).

    Returns:
        The first candidate reaching the highest F1. Ties go to the earliest
        candidate, i.e. the lowest threshold with the default grid. If no
        candidate achieves F1 > 0, the first candidate is returned.

    Raises:
        ValueError: if ``thresholds`` is empty.
    """
    if thresholds is None:
        thresholds = np.arange(0.01, 0.5, 0.01)
    grid = np.asarray(thresholds, dtype=float)
    if grid.size == 0:
        raise ValueError("thresholds must contain at least one candidate")

    best_thr, best_f1 = grid[0], 0.0
    for thr in grid:
        preds = (y_probs >= thr).astype(int)
        f1 = f1_score(y_true, preds, average="binary", zero_division=0)
        # Strict '>' keeps the earliest threshold on ties.
        if f1 > best_f1:
            best_f1 = f1
            best_thr = thr

    return best_thr

# Tune one threshold per class, each independently on that class's column.
# NOTE(review): y_true_ag / y_pred_ag_probs come from the test files, so these
# thresholds are fitted on the same data the later metrics are reported on.
PER_CLASS_THRESHOLDS = np.array(
    [find_best_threshold(y_true_ag[:, k], y_pred_ag_probs[:, k]) for k in range(NUM_CLASSES)]
)
print("Per-class thresholds:", PER_CLASS_THRESHOLDS)

Per-class thresholds: [0.06 0.14 0.13 0.13 0.07 0.06 0.02 0.17 0.08 0.11 0.24]


## Step 12: Apply Per-Class Thresholds

In [16]:
def apply_per_class_threshold(probs, thresholds):
    """Binarise scores: 1 where a score is >= its class threshold, else 0.

    ``thresholds`` is broadcast against ``probs`` (e.g. one value per column).
    """
    hits = np.greater_equal(probs, thresholds)
    return hits.astype(int)

# Binarise the aggregated probabilities with the tuned per-class thresholds.
y_pred_ag = apply_per_class_threshold(y_pred_ag_probs, PER_CLASS_THRESHOLDS)

## Step 13: Metrics - With Aggregation

In [17]:
# Micro/macro F1 plus a per-class breakdown for the averaged-probability model.
print("WITH AGGREGATION")
for metric_label, avg_mode in (("Micro F1:", "micro"), ("Macro F1:", "macro")):
    print(metric_label, f1_score(y_true_ag, y_pred_ag, average=avg_mode))

report_ag = classification_report(
    y_true_ag, y_pred_ag, target_names=class_names, zero_division=0
)
print(report_ag)

WITH AGGREGATION
Micro F1: 0.5686367218282112
Macro F1: 0.4973861542082165
              precision    recall  f1-score   support

         cel       0.09      0.50      0.16        46
         cla       0.25      0.56      0.34        36
         flu       0.28      0.50      0.36        76
         gac       0.62      0.51      0.56       294
         gel       0.47      0.61      0.53       487
         org       0.28      0.58      0.38       191
         pia       0.56      0.66      0.61       620
         sax       0.75      0.78      0.76       232
         tru       0.29      0.68      0.41       111
         vio       0.38      0.65      0.48       135
         voi       0.91      0.85      0.88       483

   micro avg       0.50      0.67      0.57      2711
   macro avg       0.44      0.63      0.50      2711
weighted avg       0.57      0.67      0.60      2711
 samples avg       0.56      0.71      0.58      2711



In [18]:
# Sanity check: per-class support should match the classification report.
support_per_class = y_true_ag.sum(axis=0)
print("Total positive labels in y_true_ag:", support_per_class.sum())
print("Per-class support:", support_per_class)

Total positive labels in y_true_ag: 2711
Per-class support: [ 46  36  76 294 487 191 620 232 111 135 483]


## Step 14: Without Aggregation (Baseline: per-window majority vote)

In [19]:
def predict_without_aggregation(audio_path):
    """Return per-window class probabilities for one audio file (no averaging).

    Uses the same load -> mono -> sliding_windows -> extract_features pipeline
    as predict_with_aggregation. All windows are scored in one batched
    ``model.predict`` call.

    Args:
        audio_path: path to an audio file readable by librosa.

    Returns:
        2-D array with one row of model outputs per window.

    Raises:
        ValueError: if no analysis window could be extracted, e.g. the audio
            is empty or shorter than one window.
    """
    y, _ = librosa.load(audio_path, sr=TARGET_SR, mono=False)
    y = stereo_to_mono(y)

    features = [extract_features(window) for window in sliding_windows(y, sr=TARGET_SR)]
    if not features:
        # An empty result would make the per-class vote in
        # track_decision_no_aggregation fail with a broadcasting error.
        raise ValueError(
            f"No analysis windows extracted from {audio_path!r}; "
            "the audio is empty or shorter than one window."
        )

    # (n_windows, N_MELS, TARGET_FRAMES) -> add the trailing channel axis.
    batch = np.stack(features)[..., np.newaxis]
    return model.predict(batch, verbose=0)

def track_decision_no_aggregation(segment_probs, thresholds, min_vote_fraction=0.5):
    """Turn per-window probabilities into a track-level label vector by voting.

    Each window votes for every class whose probability is >= that class's
    threshold. A class is predicted for the track when at least
    ``min_vote_fraction`` of the windows voted for it; an exact tie at the
    fraction counts as positive.

    Args:
        segment_probs: array with one row per window and one column per class.
        thresholds: per-class thresholds, broadcast against each row.
        min_vote_fraction: required share of positive windows (default 0.5).

    Returns:
        int array with a 0/1 decision per class.

    Raises:
        ValueError: if ``segment_probs`` is empty, since the vote share would be NaN.
    """
    segment_probs = np.asarray(segment_probs)
    if segment_probs.size == 0:
        raise ValueError("segment_probs contains no windows; cannot vote")
    votes = (segment_probs >= thresholds).mean(axis=0)
    return (votes >= min_vote_fraction).astype(int)

y_true_na, y_pred_na = [], []

# NOTE: recomputes every window prediction already made in Step 10 (about
# 1 h 10 min in the recorded run); caching per-window probabilities there
# would avoid the second pass.
for wav in tqdm(test_files, desc="No aggregation inference"):
    audio_path = os.path.join(TEST_AUDIO_ROOT, wav)
    # Swap only the extension; str.replace(".wav", ".txt") would also rewrite a
    # ".wav" appearing anywhere else in the path.
    txt_path = os.path.splitext(audio_path)[0] + ".txt"

    seg_probs = predict_without_aggregation(audio_path)
    y_pred_na.append(track_decision_no_aggregation(seg_probs, PER_CLASS_THRESHOLDS))
    y_true_na.append(load_multilabel_gt(txt_path))

# One row per test file: multi-hot ground truth and majority-vote predictions.
y_true_na = np.stack(y_true_na)
y_pred_na = np.stack(y_pred_na)

No aggregation inference: 100%|██████████████████████████████████████████████████| 1573/1573 [1:10:42<00:00,  2.70s/it]


## Step 15: Metrics - Without Aggregation

In [20]:
# Micro/macro F1 plus a per-class breakdown for the per-window voting baseline.
print("WITHOUT AGGREGATION")
for metric_label, avg_mode in (("Micro F1:", "micro"), ("Macro F1:", "macro")):
    print(metric_label, f1_score(y_true_na, y_pred_na, average=avg_mode))

report_na = classification_report(
    y_true_na, y_pred_na, target_names=class_names, zero_division=0
)
print(report_na)

WITHOUT AGGREGATION
Micro F1: 0.5566476978789446
Macro F1: 0.4757831854817725
              precision    recall  f1-score   support

         cel       0.07      0.24      0.11        46
         cla       0.26      0.42      0.32        36
         flu       0.30      0.46      0.36        76
         gac       0.64      0.48      0.55       294
         gel       0.49      0.56      0.52       487
         org       0.27      0.45      0.34       191
         pia       0.56      0.57      0.57       620
         sax       0.74      0.75      0.74       232
         tru       0.31      0.55      0.40       111
         vio       0.39      0.56      0.46       135
         voi       0.93      0.81      0.87       483

   micro avg       0.52      0.60      0.56      2711
   macro avg       0.45      0.53      0.48      2711
weighted avg       0.58      0.60      0.58      2711
 samples avg       0.59      0.65      0.56      2711



In [21]:
# Sanity check: confirm each of the first clips has a sibling .txt label file
# and show its raw contents. repr() exposes stray tabs/newlines that
# load_multilabel_gt has to strip.
for wav in test_files[:10]:
    audio_path = os.path.join(TEST_AUDIO_ROOT, wav)
    txt_path = os.path.splitext(audio_path)[0] + ".txt"

    print("WAV:", wav)
    print("TXT exists:", os.path.exists(txt_path))
    if os.path.exists(txt_path):
        # Decode as UTF-8, like the label loader, instead of the platform
        # default encoding; a BOM, if present, shows up here as '\ufeff'.
        with open(txt_path, encoding="utf-8") as f:
            content = f.read().strip()
            print("TXT content:", repr(content))
    print("-" * 40)

WAV: (02) dont kill the whale-1.wav
TXT exists: True
TXT content: 'gel'
----------------------------------------
WAV: (02) dont kill the whale-11.wav
TXT exists: True
TXT content: 'gel'
----------------------------------------
WAV: (02) dont kill the whale-12.wav
TXT exists: True
TXT content: 'gel\t\nvoi'
----------------------------------------
WAV: (02) dont kill the whale-13.wav
TXT exists: True
TXT content: 'gel\t\nvoi'
----------------------------------------
WAV: (02) dont kill the whale-14.wav
TXT exists: True
TXT content: 'gel\t\nvoi'
----------------------------------------
WAV: (02) dont kill the whale-15.wav
TXT exists: True
TXT content: 'gel\t\npia'
----------------------------------------
WAV: (02) dont kill the whale-2.wav
TXT exists: True
TXT content: 'gel\t\nvoi'
----------------------------------------
WAV: (02) dont kill the whale-3.wav
TXT exists: True
TXT content: 'gel\t\nvoi'
----------------------------------------
WAV: (02) dont kill the whale-4.wav
TXT exists: T