## Step 1: Imports

In [1]:
import os
import numpy as np
import librosa
import tensorflow as tf
from tqdm import tqdm
from sklearn.metrics import classification_report, f1_score

## Step 2: Paths & Constants (MATCH TRAINING)

In [2]:
# Paths
TEST_AUDIO_ROOT = r"E:\InstruNet-AI\data\polyphonic_test_data"
MODEL_PATH = r"E:\InstruNet-AI\saved_models\best_baseline_regularized.h5"

# Audio / feature parameters (LOCKED)
TARGET_SR = 16000
WINDOW_SEC = 3.0
HOP_SEC = 1.5
N_MELS = 128
TARGET_FRAMES = 126
NUM_CLASSES = 11
EPS = 1e-8
GLOBAL_THRESHOLD = 0.25

## Step 3: Class Mapping

In [3]:
class_names = [
    "cel", "cla", "flu", "gac", "gel",
    "org", "pia", "sax", "tru", "vio", "voi"
]

class_to_id = {c: i for i, c in enumerate(class_names)}
id_to_class = {i: c for c, i in class_to_id.items()}

## Step 4: Load Trained Model

In [4]:
model = tf.keras.models.load_model(MODEL_PATH)
model.summary()



## Step 5: Preprocessing Utilities (EXACT FROM TRAINING)

In [5]:
def stereo_to_mono(audio):
    if audio.ndim == 1:
        return audio
    return np.mean(audio, axis=0)

def peak_normalize(audio):
    peak = np.max(np.abs(audio))
    return audio / peak if peak > 0 else audio

def trim_silence(audio, thresh=0.02):
    idx = np.where(np.abs(audio) > thresh)[0]
    if len(idx) == 0:
        return audio
    return audio[idx[0]: idx[-1]]

def fix_duration(audio, sr=TARGET_SR, duration=WINDOW_SEC):
    target_len = int(sr * duration)
    if len(audio) > target_len:
        return audio[:target_len]
    else:
        return np.pad(audio, (0, target_len - len(audio)), mode="constant")

## Step 6: Log-Mel Spectrogram

In [6]:
def generate_log_mel(audio, sr=TARGET_SR):
    mel = librosa.feature.melspectrogram(
        y=audio,
        sr=sr,
        n_fft=2048,
        hop_length=512,
        win_length=2048,
        window="hann",
        n_mels=N_MELS,
        power=2.0
    )

    mel_db = librosa.power_to_db(mel, ref=np.max)
    mel_db = (mel_db - mel_db.mean()) / (mel_db.std() + EPS)
    return mel_db

## Step 7: Frame Alignment

In [7]:
def fix_mel_frames(mel, target_frames=TARGET_FRAMES):
    if mel.shape[1] < target_frames:
        mel = np.pad(
            mel,
            ((0, 0), (0, target_frames - mel.shape[1])),
            mode="constant"
        )
    else:
        mel = mel[:, :target_frames]
    return mel

## Step 8: Final Feature Extraction

In [8]:
def extract_features(y, sr=TARGET_SR):
    y = stereo_to_mono(y)
    y = peak_normalize(y)
    y = trim_silence(y, thresh=0.02)
    y = fix_duration(y, sr)

    mel = generate_log_mel(y, sr)
    mel = fix_mel_frames(mel, TARGET_FRAMES)

    return mel

## Step 9: Sliding Window Generator

In [9]:
def sliding_windows(y, sr):
    win_len = int(sr * WINDOW_SEC)
    hop_len = int(sr * HOP_SEC)

    for start in range(0, len(y) - win_len + 1, hop_len):
        yield y[start:start + win_len]

## Step 10: Track-Level Prediction (Mean Aggregation)

In [10]:
def predict_track(audio_path):
    y, sr = librosa.load(audio_path, sr=TARGET_SR, mono=False)
    y = stereo_to_mono(y)

    window_probs = []

    for window in sliding_windows(y, sr):
        mel = extract_features(window, sr)
        mel = mel[..., np.newaxis]
        mel = mel[np.newaxis, ...]

        probs = model.predict(mel, verbose=0)[0]
        window_probs.append(probs)

    window_probs = np.array(window_probs)
    return window_probs.mean(axis=0)

## Step 11: Load Multi-Label Ground Truth

In [11]:
def load_multilabel_gt(txt_path):
    labels = np.zeros(NUM_CLASSES)
    with open(txt_path, "r") as f:
        instruments = f.read().strip().split("\n")
        for inst in instruments:
            if inst in class_to_id:
                labels[class_to_id[inst]] = 1
    return labels

## Step 12: Global Thresholding (NO HARD-CODING PER CLASS)

In [12]:
def apply_global_threshold(probs, threshold=GLOBAL_THRESHOLD):
    return (probs >= threshold).astype(int)

## Step 13: Collect Test Files

In [13]:
test_files = [f for f in os.listdir(TEST_AUDIO_ROOT) if f.endswith(".wav")]
print("Number of test files:", len(test_files))

Number of test files: 1573


## Step 14: Run Test Evaluation

In [14]:
y_test_true = []
y_test_pred = []

for wav in tqdm(test_files, desc="Test inference"):
    audio_path = os.path.join(TEST_AUDIO_ROOT, wav)
    txt_path = audio_path.replace(".wav", ".txt")

    gt = load_multilabel_gt(txt_path)
    probs = predict_track(audio_path)
    pred = apply_global_threshold(probs)

    y_test_true.append(gt)
    y_test_pred.append(pred)

y_test_true = np.stack(y_test_true, axis=0)
y_test_pred = np.stack(y_test_pred, axis=0)

Test inference: 100%|██████████████████████████████████████████████████████████████| 1573/1573 [54:50<00:00,  2.09s/it]


## Step 15: Final Metrics

In [15]:
print("Global Threshold:", GLOBAL_THRESHOLD)

print("Micro F1 :", f1_score(y_test_true, y_test_pred, average="micro"))
print("Macro F1 :", f1_score(y_test_true, y_test_pred, average="macro"))

print(classification_report(
    y_test_true,
    y_test_pred,
    target_names=class_names,
    zero_division=0
))

Global Threshold: 0.25
Micro F1 : 0.5603803486529318
Macro F1 : 0.35965740293107323
              precision    recall  f1-score   support

         cel       0.00      0.00      0.00         0
         cla       0.02      0.04      0.03        23
         flu       0.00      0.00      0.00         4
         gac       0.47      0.91      0.62       119
         gel       0.24      0.34      0.29       145
         org       0.21      0.34      0.26        77
         pia       0.79      0.63      0.70       326
         sax       0.72      0.68      0.70       185
         tru       0.28      0.25      0.26       109
         vio       0.91      0.21      0.34       102
         voi       0.88      0.66      0.76       483

   micro avg       0.56      0.56      0.56      1573
   macro avg       0.41      0.37      0.36      1573
weighted avg       0.67      0.56      0.59      1573
 samples avg       0.47      0.56      0.50      1573

