In [None]:
import os
import numpy as np
import scipy
from scipy.io import wavfile
import scipy.fftpack as fft
from scipy.signal import get_window
import IPython.display as ipd
import matplotlib.pyplot as plt
import librosa
import librosa.display
import soundfile as sf
from tempfile import NamedTemporaryFile
from tqdm.auto import tqdm
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv, global_mean_pool
from sklearn.metrics import f1_score

%matplotlib inline

In [None]:
!pip install datasets soundfile torchaudio
!pip install --upgrade datasets torchcodec

Collecting datasets
  Downloading datasets-4.4.1-py3-none-any.whl.metadata (19 kB)
Collecting torchcodec
  Downloading torchcodec-0.9.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (11 kB)
Collecting pyarrow>=21.0.0 (from datasets)
  Downloading pyarrow-22.0.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (3.2 kB)
Downloading datasets-4.4.1-py3-none-any.whl (511 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m511.6/511.6 kB[0m [31m15.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading torchcodec-0.9.0-cp312-cp312-manylinux_2_28_x86_64.whl (2.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m70.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pyarrow-22.0.0-cp312-cp312-manylinux_2_28_x86_64.whl (47.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.7/47.7 MB[0m [31m56.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torchcodec, pyarrow, datasets
  Attempting uninstall: pyarrow
    Found

In [None]:
from datasets import load_dataset

ds = load_dataset("agkphysics/AudioSet", "balanced", split="train")

# GTZAN
gtzan_ds = load_dataset("mteb/gtzan-genre", split="train")

# UrbanSound8K
us8k_ds = load_dataset("danavery/urbansound8K", split="train")

print(ds)
print(gtzan_ds)
print(us8k_ds)

Resolving data files:   0%|          | 0/38 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/35 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/38 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/35 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/47 [00:00<?, ?it/s]

README.md: 0.00B [00:00, ?B/s]

Downloading data:   0%|          | 0/16 [00:00<?, ?files/s]

data/train-00000-of-00016-e478d7cccca6a0(…):   0%|          | 0.00/434M [00:00<?, ?B/s]

data/train-00001-of-00016-299138aa39afaa(…):   0%|          | 0.00/472M [00:00<?, ?B/s]

data/train-00002-of-00016-887e0748205b6f(…):   0%|          | 0.00/384M [00:00<?, ?B/s]

data/train-00003-of-00016-691ee48aa53d9c(…):   0%|          | 0.00/447M [00:00<?, ?B/s]

data/train-00004-of-00016-c0f37514d8e28a(…):   0%|          | 0.00/441M [00:00<?, ?B/s]

data/train-00005-of-00016-55ef1a0a51149c(…):   0%|          | 0.00/591M [00:00<?, ?B/s]

data/train-00006-of-00016-0ef363072505e6(…):   0%|          | 0.00/496M [00:00<?, ?B/s]

data/train-00007-of-00016-dfac173beb21e5(…):   0%|          | 0.00/588M [00:00<?, ?B/s]

data/train-00008-of-00016-2744487f32f65d(…):   0%|          | 0.00/493M [00:00<?, ?B/s]

data/train-00009-of-00016-83fc7364d47981(…):   0%|          | 0.00/549M [00:00<?, ?B/s]

data/train-00010-of-00016-4c1d0e285ed778(…):   0%|          | 0.00/353M [00:00<?, ?B/s]

data/train-00011-of-00016-79d186503a2667(…):   0%|          | 0.00/316M [00:00<?, ?B/s]

data/train-00012-of-00016-6aff88fdcca229(…):   0%|          | 0.00/372M [00:00<?, ?B/s]

data/train-00013-of-00016-17d827b1a5be04(…):   0%|          | 0.00/348M [00:00<?, ?B/s]

data/train-00014-of-00016-c630762df85f6c(…):   0%|          | 0.00/381M [00:00<?, ?B/s]

data/train-00015-of-00016-03506887d89adf(…):   0%|          | 0.00/335M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/8732 [00:00<?, ? examples/s]

Dataset({
    features: ['video_id', 'audio', 'labels', 'human_labels'],
    num_rows: 18683
})
Dataset({
    features: ['audio', 'label'],
    num_rows: 1000
})
Dataset({
    features: ['audio', 'slice_file_name', 'fsID', 'start', 'end', 'salience', 'fold', 'classID', 'class'],
    num_rows: 8732
})


In [None]:
def feature_extraction(x, sample_rate, target_size=128, duration=10.0):
    """
    Returns a 128x128 MFCC 'image' for one audio clip.
    """
    # 1) Fix waveform length
    target_len = int(sample_rate * duration)
    if len(x) < target_len:
        x = np.pad(x, (0, target_len - len(x)))
    else:
        x = x[:target_len]

    # 2) Compute MFCCs: 128 coefficients -> 128 frequency bins
    mfcc = librosa.feature.mfcc(
        y=x,
        sr=sample_rate,
        n_mfcc=target_size
    )  # shape: (n_mfcc, T)

    # 3) Force the time dimension to 128 frames
    if mfcc.shape[1] < target_size:
        mfcc = librosa.util.fix_length(mfcc, size=target_size, axis=1)
    else:
        mfcc = mfcc[:, :target_size]

    # Results in 128x128 MFCC
    return mfcc.astype(np.float32)

In [None]:
from tqdm.auto import tqdm
from datasets import ClassLabel

def build_mfcc_npz_from_hf(hf_ds, label_key, npz_path,
                           duration=10.0, target_size=128):
    """
    Convert a HuggingFace audio dataset into:
      X: (N, 128, 128) MFCC 'images'
      Y: (N, C) one-hot labels
      label_vocab: list of label names

    hf_ds: loaded HF dataset (e.g., gtzan_ds)
    label_key: column name with the class (e.g., 'label' or 'class')
    """
    X_list = []
    labels_raw = []

    for ex in tqdm(hf_ds, desc=f"Building MFCCs for {npz_path}"):
        audio = ex["audio"]
        x = audio["array"]
        sr = audio["sampling_rate"]

        mfcc_img = feature_extraction(x, sr,
                                      target_size=target_size,
                                      duration=duration)
        X_list.append(mfcc_img)
        labels_raw.append(ex[label_key])

    X = np.stack(X_list).astype(np.float32)

    feat = hf_ds.features.get(label_key, None)

    if isinstance(feat, ClassLabel):
        label_vocab = list(feat.names)
        label_ids = np.array(labels_raw, dtype=np.int64)
    else:
        label_vocab = sorted(set(labels_raw))
        label_to_idx = {lab: i for i, lab in enumerate(label_vocab)}
        label_ids = np.array([label_to_idx[lab] for lab in labels_raw],
                             dtype=np.int64)

    num_classes = len(label_vocab)
    N = len(label_ids)

    Y = np.zeros((N, num_classes), dtype=np.float32)
    Y[np.arange(N), label_ids] = 1.0

    np.savez_compressed(
        npz_path,
        X=X,
        Y=Y,
        label_vocab=np.array(label_vocab, dtype=object)
    )
    print(f"Saved {X.shape[0]} examples, {num_classes} classes to {npz_path}")

In [None]:
build_mfcc_npz_from_hf(
    hf_ds=gtzan_ds,
    label_key="label",
    npz_path="gtzan_mfcc128x128.npz",
    duration=30.0,
    target_size=128
)

build_mfcc_npz_from_hf(
    hf_ds=us8k_ds,
    label_key="class",
    npz_path="urbansound8k_mfcc128x128.npz",
    duration=4.0,
    target_size=128
)

Building MFCCs for gtzan_mfcc128x128.npz:   0%|          | 0/1000 [00:00<?, ?it/s]

Saved 1000 examples, 10 classes to gtzan_mfcc128x128.npz


Building MFCCs for urbansound8k_mfcc128x128.npz:   0%|          | 0/8732 [00:00<?, ?it/s]

  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)
  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)
  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)
  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)
  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)
  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)
  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)
  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)
  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)
  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)
  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)
  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)
  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)
  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)
  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)
  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)
  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)


Saved 8732 examples, 10 classes to urbansound8k_mfcc128x128.npz


In [None]:
def build_single_label_mfcc_npz(
    hf_dataset_name,
    split,
    label_key,
    npz_path,
    duration=10.0,
    target_size=128
):
    """
    Generic helper to turn a HuggingFace audio dataset with a *single label per clip*
    into an MFCC 128x128 dataset saved as .npz.

    - hf_dataset_name: e.g. "marsyas/gtzan"
    - split: e.g. "train"
    - label_key: key for the label in each example (e.g., "genre", "class", "label")
    - npz_path: path to save e.g. "gtzan_mfcc128x128.npz"
    """
    ds = load_dataset(hf_dataset_name, split=split)

    X_list = []
    labels_list = []

    print("Building MFCCs for", hf_dataset_name, "split:", split)
    for ex in tqdm(ds):
        audio = ex["audio"]
        x = audio["array"]
        sr = audio["sampling_rate"]

        # 128x128 MFCC image (reuse your feature_extraction)
        mfcc_img = feature_extraction(x, sr, target_size=target_size, duration=duration)
        X_list.append(mfcc_img)

        lab = ex[label_key]    # single label
        labels_list.append(str(lab))

    X = np.stack(X_list)  # (N, 128, 128)

    # Build label vocabulary and one-hot label matrix
    label_vocab = sorted(set(labels_list))
    label_to_idx = {lab: i for i, lab in enumerate(label_vocab)}
    num_classes = len(label_vocab)

    Y = np.zeros((len(labels_list), num_classes), dtype=np.float32)
    for i, lab in enumerate(labels_list):
        Y[i, label_to_idx[lab]] = 1.0

    print("X shape:", X.shape, "Y shape:", Y.shape, "#classes:", num_classes)
    print("Labels:", label_vocab)

    np.savez_compressed(
        npz_path,
        X=X,
        Y=Y,
        label_vocab=np.array(label_vocab, dtype=object)
    )
    print("Saved to", npz_path)

In [None]:

X = []
Y = []
IDS = []

for example in tqdm(ds):
    audio = example["audio"]
    x = audio["array"]
    sr = audio["sampling_rate"]

    vid = example["video_id"]
    human_labels = example["human_labels"]

    mfcc_img = feature_extraction(x, sr)

    X.append(mfcc_img)
    Y.append(human_labels)
    IDS.append(vid)

X = np.stack(X)
IDS = np.array(IDS)

print("Feature tensor shape:", X.shape)
print("Example labels:", Y[0])
print("Example id:", IDS[0])

  0%|          | 0/18683 [00:00<?, ?it/s]

Feature tensor shape: (18683, 128, 128)
Example labels: ['Speech', 'Gush']
Example id: --PJHxphWEs


In [None]:
# 1) Build vocabulary
from itertools import chain

all_labels = sorted(set(chain.from_iterable(Y)))
label_to_idx = {lab: i for i, lab in enumerate(all_labels)}
num_classes = len(all_labels)
print("Num classes:", num_classes)

# 2) Convert label lists to multi-hot vectors
def labels_to_multihot(label_list):
    vec = np.zeros(num_classes, dtype=np.float32)
    for lab in label_list:
        if lab in label_to_idx:
            vec[label_to_idx[lab]] = 1.0
    return vec

Y_multihot = np.stack([labels_to_multihot(labs) for labs in Y])
print("Y_multihot shape:", Y_multihot.shape)

Num classes: 527
Y_multihot shape: (18683, 527)


In [None]:

# NEW: restrict to top-K most frequent labels to simplify task
TOP_K = 20

label_freq = Y_multihot.sum(axis=0)
topk_idx = np.argsort(label_freq)[::-1][:TOP_K]

all_labels_array = np.array(all_labels, dtype=object)
print("Top-K indices:", topk_idx)
print("Top-K labels:", all_labels_array[topk_idx])

# keep only those K labels
Y_top = Y_multihot[:, topk_idx]

# drop samples with no positive among these K labels
mask = Y_top.sum(axis=1) > 0
X_top   = X[mask]
Y_top   = Y_top[mask]
IDS_top = IDS[mask]

label_vocab_top = all_labels_array[topk_idx]

print("Shapes after top-K filtering:", X_top.shape, Y_top.shape)

# Save this *reduced* dataset instead

np.savez_compressed(
    "audioset_balanced_mfcc128x128_topK.npz",
    X=X_top,
    Y=Y_top,
    ids=IDS_top,
    label_vocab=label_vocab_top
)

Top-K indices: [291 418 486  15 254 297 390 144 218 320  68 143 312 516  48 495 308 307
  55 170]
Top-K labels: ['Music' 'Speech' 'Vehicle' 'Animal' 'Inside, small room'
 'Musical instrument' 'Singing' 'Domestic animals, pets' 'Guitar'
 'Plucked string instrument' 'Car' 'Dog' 'Percussion'
 'Wind instrument, woodwind instrument' 'Boat, Water vehicle' 'Water'
 'Outside, urban or manmade' 'Outside, rural or natural'
 'Brass instrument' 'Engine']
Shapes after top-K filtering: (10975, 128, 128) (10975, 20)


In [None]:
build_single_label_mfcc_npz(
    hf_dataset_name="mteb/gtzan-genre",
    split="train",                       # this dataset has a single 'train' split
    label_key="label",                   # <-- column name is 'label', not 'genre'
    npz_path="gtzan_mfcc128x128.npz",
    duration=30.0,
    target_size=128
)

Building MFCCs for mteb/gtzan-genre split: train


  0%|          | 0/1000 [00:00<?, ?it/s]

X shape: (1000, 128, 128) Y shape: (1000, 10) #classes: 10
Labels: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
Saved to gtzan_mfcc128x128.npz


After extraction, the dataset will be in a directory named `genres`. Each subdirectory within `genres` represents a genre (e.g., `blues`, `classical`, `metal`), and contains `.wav` files for that genre.

Now, we will reuse the `feature_extraction` function to convert these audio files into MFCC images and then save them into the `gtzan_mfcc128x128.npz` file, similar to how `build_single_label_mfcc_npz` would have done.

In [None]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.7.0-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.7.0-py3-none-any.whl (1.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m34.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.7.0


In [None]:
build_single_label_mfcc_npz(
    hf_dataset_name="mteb/gtzan-genre",
    split="train",            # this dataset has a single 'train' split
    label_key="label",        # <-- column name is 'label' (not 'genre')
    npz_path="gtzan_mfcc128x128.npz",
    duration=30.0,
    target_size=128
)

Building MFCCs for mteb/gtzan-genre split: train


  0%|          | 0/1000 [00:00<?, ?it/s]

X shape: (1000, 128, 128) Y shape: (1000, 10) #classes: 10
Labels: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
Saved to gtzan_mfcc128x128.npz


In [None]:
build_single_label_mfcc_npz(
    hf_dataset_name="danavery/urbansound8K",
    split="train",                            # this repo has only "train"
    label_key="class",                        # column name from the dataset card
    npz_path="urbansound8k_mfcc128x128.npz",
    duration=4.0,
    target_size=128
)

Building MFCCs for danavery/urbansound8K split: train


  0%|          | 0/8732 [00:00<?, ?it/s]

  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)
  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)
  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)
  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)
  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)
  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)
  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)
  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)
  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)
  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)
  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)
  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)
  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)
  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)
  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)
  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)
  mel_basis = filters.mel(sr=sr, n_fft=n_fft, **kwargs)


X shape: (8732, 128, 128) Y shape: (8732, 10) #classes: 10
Labels: ['air_conditioner', 'car_horn', 'children_playing', 'dog_bark', 'drilling', 'engine_idling', 'gun_shot', 'jackhammer', 'siren', 'street_music']
Saved to urbansound8k_mfcc128x128.npz


In [None]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Load precomputed MFCC data
data = np.load("audioset_balanced_mfcc128x128_topK.npz", allow_pickle=True)
X = data["X"]
Y = data["Y"]
label_vocab = data["label_vocab"]

num_classes = Y.shape[1]
print("X shape:", X.shape, "Y shape:", Y.shape, "#classes:", num_classes)
print("Labels:", label_vocab)


class MFCCDataset(Dataset):
    """Simple dataset: 128x128 MFCC image + multi-hot label vector"""
    def __init__(self, X, Y):
        self.X = torch.from_numpy(X).unsqueeze(1).float()
        self.Y = torch.from_numpy(Y).float()

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx]


full_dataset = MFCCDataset(X, Y)

N = len(full_dataset)
n_train = int(0.7 * N)
n_val = int(0.15 * N)
n_test = N - n_train - n_val
train_ds, val_ds, test_ds = random_split(full_dataset, [n_train, n_val, n_test])

batch_size = 32
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader   = DataLoader(val_ds,   batch_size=batch_size*2, shuffle=False, num_workers=2)
test_loader  = DataLoader(test_ds,  batch_size=batch_size*2, shuffle=False, num_workers=2)

Using device: cuda
X shape: (10975, 128, 128) Y shape: (10975, 20) #classes: 20
Labels: ['Music' 'Speech' 'Vehicle' 'Animal' 'Inside, small room'
 'Musical instrument' 'Singing' 'Domestic animals, pets' 'Guitar'
 'Plucked string instrument' 'Car' 'Dog' 'Percussion'
 'Wind instrument, woodwind instrument' 'Boat, Water vehicle' 'Water'
 'Outside, urban or manmade' 'Outside, rural or natural'
 'Brass instrument' 'Engine']


In [None]:
def make_loaders_from_npz(npz_path, batch_size=32,
                          train_frac=0.7, val_frac=0.15):
    """
    Load X, Y from an .npz and return:
      train_loader, val_loader, test_loader, num_classes, label_vocab
    """
    data = np.load(npz_path, allow_pickle=True)
    X = data["X"]
    Y = data["Y"]
    label_vocab = data["label_vocab"]
    num_classes = Y.shape[1]

    full_dataset = MFCCDataset(X, Y)
    N = len(full_dataset)
    n_train = int(train_frac * N)
    n_val = int(val_frac * N)
    n_test = N - n_train - n_val

    train_ds, val_ds, test_ds = random_split(full_dataset,
                                             [n_train, n_val, n_test])

    train_loader = DataLoader(train_ds, batch_size=batch_size,
                              shuffle=True, num_workers=2)
    val_loader   = DataLoader(val_ds,   batch_size=batch_size*2,
                              shuffle=False, num_workers=2)
    test_loader  = DataLoader(test_ds,  batch_size=batch_size*2,
                              shuffle=False, num_workers=2)

    return train_loader, val_loader, test_loader, num_classes, label_vocab

In [None]:
class CNNBackbone(nn.Module):
    """
    Small CNN that turns a (1, 128, 128) MFCC image into a feature map.
    Output feature map will be (C_out, H_out, W_out).
    """
    def __init__(self, in_channels=1, base_channels=32):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(base_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),          # 128 -> 64

            nn.Conv2d(base_channels, base_channels * 2, kernel_size=3, padding=1),
            nn.BatchNorm2d(base_channels * 2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),          # 64 -> 32

            nn.Conv2d(base_channels * 2, base_channels * 4, kernel_size=3, padding=1),
            nn.BatchNorm2d(base_channels * 4),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),          # 32 -> 16  (grid will be 16x16)
        )
        self.out_channels = base_channels * 4  # 128

    def forward(self, x):
        return self.features(x)

In [None]:

def make_grid_edge_index(height, width, device=None):
    """
    Build 4-neighbour grid edges for an HxW grid of nodes.
    Nodes are indexed row-major: i * width + j.
    Returns edge_index of shape (2, E).
    """
    edges = []
    def node_id(i, j):
        return i * width + j

    for i in range(height):
        for j in range(width):
            u = node_id(i, j)
            # down neighbor
            if i + 1 < height:
                v = node_id(i + 1, j)
                edges.append((u, v))
                edges.append((v, u))
            # right neighbor
            if j + 1 < width:
                v = node_id(i, j + 1)
                edges.append((u, v))
                edges.append((v, u))

    edge_index = torch.tensor(edges, dtype=torch.long).t()
    if device is not None:
        edge_index = edge_index.to(device)
    return edge_index

In [None]:
class HybridRCAModel(nn.Module):
    """
    Hybrid Relational-Convolutional-Attention model:
    - CNN on 128x128 MFCC image
    - Interpret feature map as grid graph nodes
    - GAT layers + global pooling
    - MLP classifier
    """
    def __init__(self, num_classes, hidden_dim=64, dropout=0.3):
        super().__init__()
        self.cnn = CNNBackbone()

        # Infer feature-map size (H_out, W_out) using a dummy forward
        with torch.no_grad():
            dummy = torch.zeros(1, 1, 128, 128)
            fm = self.cnn(dummy)
        _, C, H, W = fm.shape
        self.cnn_out_channels = C
        self.grid_h = H
        self.grid_w = W
        self.num_nodes = H * W

        # Precompute single-graph 4-neighbour edges and register as buffer
        edge_index_single = make_grid_edge_index(H, W)
        self.register_buffer("edge_index_single", edge_index_single)

        # GAT layers
        self.gat1 = GATv2Conv(self.cnn_out_channels, hidden_dim, heads=1, concat=True)
        self.gat2 = GATv2Conv(hidden_dim, hidden_dim, heads=1, concat=True)

        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, x):
        """
        x: (B, 1, 128, 128)
        returns logits: (B, num_classes)
        """
        B = x.size(0)

        # 1) CNN feature map: (B, C, H, W)
        fm = self.cnn(x)
        B, C, H, W = fm.shape
        N = H * W

        # 2) Flatten to node features: (B*N, C)
        nodes = fm.view(B, C, N).permute(0, 2, 1).contiguous()
        nodes = nodes.view(B * N, C)

        # 3) Build batched edge_index and batch vector
        E = self.edge_index_single.size(1)
        device = self.edge_index_single.device

        # offsets for each graph in the batch
        offsets = (torch.arange(B, device=device).repeat_interleave(E) * N)
        edge_index = self.edge_index_single.repeat(1, B) + offsets

        # batch vector: which graph each node belongs to
        batch = torch.arange(B, device=device).repeat_interleave(N)

        # 4) GAT layers
        x_g = F.elu(self.gat1(nodes, edge_index))
        x_g = self.dropout(x_g)
        x_g = F.elu(self.gat2(x_g, edge_index))
        x_g = self.dropout(x_g)

        # 5) Global pooling over nodes -> graph embedding (B, hidden_dim)
        graph_emb = global_mean_pool(x_g, batch)

        # 6) MLP classifier
        logits = self.fc(graph_emb)
        return logits

In [None]:
def train_one_epoch(model, loader, optimizer, device):
    model.train()
    total_loss = 0.0
    for xb, yb in loader:
        xb = xb.to(device)
        yb = yb.to(device)

        optimizer.zero_grad()
        logits = model(xb)
        loss = F.binary_cross_entropy_with_logits(logits, yb)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * xb.size(0)

    avg_loss = total_loss / len(loader.dataset)
    return avg_loss


def evaluate(model, loader, device):
    model.eval()
    total_loss = 0.0
    all_targets = []
    all_preds = []

    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(device)
            yb = yb.to(device)

            logits = model(xb)
            loss = F.binary_cross_entropy_with_logits(logits, yb)
            total_loss += loss.item() * xb.size(0)

            probs = torch.sigmoid(logits)
            preds = (probs > 0.5).float()

            all_targets.append(yb.cpu().numpy())
            all_preds.append(preds.cpu().numpy())

    avg_loss = total_loss / len(loader.dataset)
    all_targets = np.concatenate(all_targets, axis=0)
    all_preds = np.concatenate(all_preds, axis=0)

    macro_f1 = f1_score(all_targets, all_preds, average="macro", zero_division=0)
    return avg_loss, macro_f1


def train_model(model, train_loader, val_loader, device,
                epochs=20, lr=1e-3):
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    best_val_f1 = 0.0
    best_state = None

    for epoch in range(1, epochs + 1):
        train_loss = train_one_epoch(model, train_loader, optimizer, device)
        val_loss, val_f1 = evaluate(model, val_loader, device)

        print(f"Epoch {epoch:02d} | "
              f"train_loss={train_loss:.4f} | "
              f"val_loss={val_loss:.4f} | "
              f"val_macroF1={val_f1:.4f}")

        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            best_state = {k: v.cpu() for k, v in model.state_dict().items()}

    if best_state is not None:
        model.load_state_dict(best_state)

    print("Best validation macro F1:", best_val_f1)
    return model, best_val_f1

In [None]:
hybrid = HybridRCAModel(num_classes=num_classes, hidden_dim=64, dropout=0.3)
hybrid, best_val_f1 = train_model(hybrid, train_loader, val_loader, device,
                                  epochs=15, lr=1e-3)

Epoch 01 | train_loss=0.2083 | val_loss=0.1689 | val_macroF1=0.0716
Epoch 02 | train_loss=0.1664 | val_loss=0.1643 | val_macroF1=0.0711
Epoch 03 | train_loss=0.1603 | val_loss=0.1587 | val_macroF1=0.0709
Epoch 04 | train_loss=0.1573 | val_loss=0.1602 | val_macroF1=0.0750
Epoch 05 | train_loss=0.1559 | val_loss=0.1556 | val_macroF1=0.0756
Epoch 06 | train_loss=0.1541 | val_loss=0.1595 | val_macroF1=0.0892
Epoch 07 | train_loss=0.1536 | val_loss=0.1673 | val_macroF1=0.0735
Epoch 08 | train_loss=0.1510 | val_loss=0.1518 | val_macroF1=0.0822
Epoch 09 | train_loss=0.1492 | val_loss=0.1539 | val_macroF1=0.0768
Epoch 10 | train_loss=0.1472 | val_loss=0.1527 | val_macroF1=0.0825
Epoch 11 | train_loss=0.1473 | val_loss=0.1552 | val_macroF1=0.0740
Epoch 12 | train_loss=0.1455 | val_loss=0.1503 | val_macroF1=0.0842
Epoch 13 | train_loss=0.1445 | val_loss=0.1519 | val_macroF1=0.0943
Epoch 14 | train_loss=0.1440 | val_loss=0.1570 | val_macroF1=0.0821
Epoch 15 | train_loss=0.1420 | val_loss=0.1583 |

In [None]:
class CNNAudioClassifier(nn.Module):
    """
    CNN-only baseline:
    - Same CNN backbone
    - Global average pooling
    - MLP classifier
    """
    def __init__(self, num_classes, hidden_dim=128, dropout=0.3):
        super().__init__()
        self.cnn = CNNBackbone()
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Sequential(
            nn.Linear(self.cnn.out_channels, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, x):
        fm = self.cnn(x)
        pooled = fm.mean(dim=(2, 3))
        pooled = self.dropout(pooled)
        logits = self.fc(pooled)
        return logits

# Example training:
cnn_baseline = CNNAudioClassifier(num_classes=num_classes, hidden_dim=128, dropout=0.3)
cnn_baseline, cnn_val_f1 = train_model(cnn_baseline, train_loader, val_loader,
                                       device, epochs=15, lr=1e-3)

Epoch 01 | train_loss=0.2020 | val_loss=0.1822 | val_macroF1=0.0613
Epoch 02 | train_loss=0.1681 | val_loss=0.1637 | val_macroF1=0.0705
Epoch 03 | train_loss=0.1624 | val_loss=0.1755 | val_macroF1=0.0826
Epoch 04 | train_loss=0.1609 | val_loss=0.1639 | val_macroF1=0.0728
Epoch 05 | train_loss=0.1587 | val_loss=0.1572 | val_macroF1=0.0735
Epoch 06 | train_loss=0.1578 | val_loss=0.1618 | val_macroF1=0.0741
Epoch 07 | train_loss=0.1563 | val_loss=0.1584 | val_macroF1=0.0754
Epoch 08 | train_loss=0.1548 | val_loss=0.1649 | val_macroF1=0.0685
Epoch 09 | train_loss=0.1547 | val_loss=0.1565 | val_macroF1=0.0771
Epoch 10 | train_loss=0.1557 | val_loss=0.1549 | val_macroF1=0.0753
Epoch 11 | train_loss=0.1536 | val_loss=0.1695 | val_macroF1=0.0722
Epoch 12 | train_loss=0.1532 | val_loss=0.1546 | val_macroF1=0.0808
Epoch 13 | train_loss=0.1525 | val_loss=0.1568 | val_macroF1=0.0793
Epoch 14 | train_loss=0.1515 | val_loss=0.1607 | val_macroF1=0.0873
Epoch 15 | train_loss=0.1508 | val_loss=0.1663 |

In [None]:
class GNNOnlyClassifier(nn.Module):
    """
    GNN-only baseline:
    - Downsample MFCC image to 16x16
    - Each pixel = node with 1 feature
    - GAT + pooling + MLP
    """
    def __init__(self, num_classes, hidden_dim=64, dropout=0.3):
        super().__init__()
        self.grid_h = 16
        self.grid_w = 16
        self.num_nodes = self.grid_h * self.grid_w

        edge_index_single = make_grid_edge_index(self.grid_h, self.grid_w)
        self.register_buffer("edge_index_single", edge_index_single)

        self.lin_in = nn.Linear(1, hidden_dim)
        self.gat1 = GATv2Conv(hidden_dim, hidden_dim, heads=1, concat=True)
        self.gat2 = GATv2Conv(hidden_dim, hidden_dim, heads=1, concat=True)

        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, x):
        B = x.size(0)

        # Downsample to 16x16
        x_ds = F.adaptive_avg_pool2d(x, (self.grid_h, self.grid_w))
        N = self.num_nodes

        # Node features: (B*N, 1)
        nodes = x_ds.view(B, 1, N).permute(0, 2, 1).contiguous()
        nodes = nodes.view(B * N, 1)
        nodes = self.lin_in(nodes)

        # Batch edges
        E = self.edge_index_single.size(1)
        device = self.edge_index_single.device
        offsets = (torch.arange(B, device=device).repeat_interleave(E) * N)
        edge_index = self.edge_index_single.repeat(1, B) + offsets
        batch = torch.arange(B, device=device).repeat_interleave(N)

        # GAT
        x_g = F.elu(self.gat1(nodes, edge_index))
        x_g = self.dropout(x_g)
        x_g = F.elu(self.gat2(x_g, edge_index))
        x_g = self.dropout(x_g)

        graph_emb = global_mean_pool(x_g, batch)
        logits = self.fc(graph_emb)
        return logits

# Example training:
gnn_only = GNNOnlyClassifier(num_classes=num_classes, hidden_dim=64, dropout=0.3)
gnn_only, gnn_val_f1 = train_model(gnn_only, train_loader, val_loader,
                                   device, epochs=15, lr=1e-3)

Epoch 01 | train_loss=0.2360 | val_loss=0.1882 | val_macroF1=0.0534
Epoch 02 | train_loss=0.1897 | val_loss=0.1848 | val_macroF1=0.0526
Epoch 03 | train_loss=0.1869 | val_loss=0.1832 | val_macroF1=0.0569
Epoch 04 | train_loss=0.1853 | val_loss=0.1820 | val_macroF1=0.0569
Epoch 05 | train_loss=0.1828 | val_loss=0.1811 | val_macroF1=0.0573
Epoch 06 | train_loss=0.1818 | val_loss=0.1807 | val_macroF1=0.0599
Epoch 07 | train_loss=0.1804 | val_loss=0.1789 | val_macroF1=0.0642
Epoch 08 | train_loss=0.1792 | val_loss=0.1796 | val_macroF1=0.0629
Epoch 09 | train_loss=0.1799 | val_loss=0.1795 | val_macroF1=0.0629
Epoch 10 | train_loss=0.1782 | val_loss=0.1785 | val_macroF1=0.0648
Epoch 11 | train_loss=0.1785 | val_loss=0.1770 | val_macroF1=0.0650
Epoch 12 | train_loss=0.1781 | val_loss=0.1770 | val_macroF1=0.0648
Epoch 13 | train_loss=0.1783 | val_loss=0.1786 | val_macroF1=0.0616
Epoch 14 | train_loss=0.1776 | val_loss=0.1765 | val_macroF1=0.0646
Epoch 15 | train_loss=0.1771 | val_loss=0.1785 |

In [None]:
def evaluate_ensemble(model_a, model_b, loader, device):
    model_a.eval()
    model_b.eval()
    all_targets = []
    all_preds = []

    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(device)
            yb = yb.to(device)

            logits_a = model_a(xb)
            logits_b = model_b(xb)

            probs_a = torch.sigmoid(logits_a)
            probs_b = torch.sigmoid(logits_b)
            probs = 0.5 * (probs_a + probs_b)

            preds = (probs > 0.5).float()

            all_targets.append(yb.cpu().numpy())
            all_preds.append(preds.cpu().numpy())

    all_targets = np.concatenate(all_targets, axis=0)
    all_preds = np.concatenate(all_preds, axis=0)
    macro_f1 = f1_score(all_targets, all_preds, average="macro", zero_division=0)
    return macro_f1

ensemble_val_f1 = evaluate_ensemble(cnn_baseline.to(device),
                                    hybrid.to(device),
                                    val_loader,
                                    device)
print("Ensemble validation macro F1:", ensemble_val_f1)

Ensemble validation macro F1: 0.08952777180118529


In [None]:
from itertools import product

hidden_dims = [64, 128]
dropouts    = [0.3, 0.5]
lrs         = [1e-3, 3e-4]
epochs      = 10
results = []

for h, d, lr in product(hidden_dims, dropouts, lrs):
    cfg = {"hidden_dim": h, "dropout": d, "lr": lr, "epochs": epochs}
    print("\n=== Config:", cfg, "===")

    model = HybridRCAModel(num_classes=num_classes,
                           hidden_dim=h,
                           dropout=d)

    model, val_f1 = train_model(model, train_loader, val_loader,
                                device, epochs=epochs, lr=lr)
    results.append((cfg, val_f1))

print("\n=== Grid search results ===")
for cfg, f1score in results:
    print(cfg, "-> val_macroF1 =", f1score)


=== Config: {'hidden_dim': 64, 'dropout': 0.3, 'lr': 0.001, 'epochs': 10} ===
Epoch 01 | train_loss=0.2082 | val_loss=0.1682 | val_macroF1=0.0708
Epoch 02 | train_loss=0.1682 | val_loss=0.1613 | val_macroF1=0.0712
Epoch 03 | train_loss=0.1611 | val_loss=0.1618 | val_macroF1=0.0733
Epoch 04 | train_loss=0.1572 | val_loss=0.1582 | val_macroF1=0.0753
Epoch 05 | train_loss=0.1570 | val_loss=0.1582 | val_macroF1=0.0783
Epoch 06 | train_loss=0.1546 | val_loss=0.1559 | val_macroF1=0.0809
Epoch 07 | train_loss=0.1531 | val_loss=0.1640 | val_macroF1=0.0716
Epoch 08 | train_loss=0.1520 | val_loss=0.1549 | val_macroF1=0.0811
Epoch 09 | train_loss=0.1506 | val_loss=0.1563 | val_macroF1=0.0817
Epoch 10 | train_loss=0.1489 | val_loss=0.1540 | val_macroF1=0.0852
Best validation macro F1: 0.08516096215598054

=== Config: {'hidden_dim': 64, 'dropout': 0.3, 'lr': 0.0003, 'epochs': 10} ===
Epoch 01 | train_loss=0.2686 | val_loss=0.1731 | val_macroF1=0.0702
Epoch 02 | train_loss=0.1780 | val_loss=0.1768 

In [None]:
from sklearn.metrics import (
    f1_score,
    accuracy_score,
    average_precision_score,
)

def evaluate_full(model, loader, device, threshold=0.5):
    """
    Evaluate a trained model on a given DataLoader.

    Returns a dict with:
      - loss (BCE)
      - macro_f1
      - micro_f1
      - subset_accuracy
      - macro_AP (mean average precision per label)
    """
    model.eval()
    all_targets = []
    all_probs = []
    total_loss = 0.0

    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(device)
            yb = yb.to(device)

            logits = model(xb)
            loss = F.binary_cross_entropy_with_logits(logits, yb)
            total_loss += loss.item() * xb.size(0)

            probs = torch.sigmoid(logits)

            all_targets.append(yb.cpu().numpy())
            all_probs.append(probs.cpu().numpy())

    # Stack over all batches
    y_true = np.concatenate(all_targets, axis=0)
    y_scores = np.concatenate(all_probs, axis=0)  =

    # Binarize with a threshold
    y_pred = (y_scores >= threshold).astype(np.float32)

    avg_loss = total_loss / len(loader.dataset)

    # Metrics
    macro_f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)
    micro_f1 = f1_score(y_true, y_pred, average="micro", zero_division=0)

    # "Exact match" accuracy: predicted set == true set
    subset_accuracy = accuracy_score(y_true, y_pred)

    # Mean Average Precision over labels
    # (ignore warnings if some labels have no positives)
    macro_AP = average_precision_score(y_true, y_scores, average="macro")

    metrics = {
        "loss": avg_loss,
        "macro_f1": macro_f1,
        "micro_f1": micro_f1,
        "subset_accuracy": subset_accuracy,
        "macro_AP": macro_AP,
    }
    return metrics

In [None]:

def find_best_threshold(model, val_loader, device, thresholds=None):
    """
    Search over a list of thresholds on the validation set and
    return (best_threshold, best_metrics_dict).
    Uses evaluate_full(...) under the hood.
    """
    if thresholds is None:
        thresholds = np.linspace(0.05, 0.5, 10)

    best_thr = None
    best_metrics = None
    best_macro_f1 = -1.0

    print("Tuning threshold...")
    for thr in thresholds:
        metrics = evaluate_full(model, val_loader, device, threshold=thr)
        macro_f1 = metrics["macro_f1"]
        micro_f1 = metrics["micro_f1"]
        print(f"thr={thr:.2f} -> macro_f1={macro_f1:.4f}, micro_f1={micro_f1:.4f}")

        if macro_f1 > best_macro_f1:
            best_macro_f1 = macro_f1
            best_thr = thr
            best_metrics = metrics

    print("Best threshold on val:", best_thr, "with macro_f1=", best_macro_f1)
    return best_thr, best_metrics


def find_best_threshold_ensemble(model_a, model_b, val_loader, device, thresholds=None):
    """
    Same as above, but for an ensemble of two models.
    Uses evaluate_ensemble_full(...) under the hood.
    """
    if thresholds is None:
        thresholds = np.linspace(0.05, 0.5, 10)

    best_thr = None
    best_metrics = None
    best_macro_f1 = -1.0

    print("Tuning ensemble threshold...")
    for thr in thresholds:
        metrics = evaluate_ensemble_full(model_a, model_b, val_loader, device,
                                         threshold=thr)
        macro_f1 = metrics["macro_f1"]
        micro_f1 = metrics["micro_f1"]
        print(f"thr={thr:.2f} -> macro_f1={macro_f1:.4f}, micro_f1={micro_f1:.4f}")

        if macro_f1 > best_macro_f1:
            best_macro_f1 = macro_f1
            best_thr = thr
            best_metrics = metrics

    print("Best ensemble threshold on val:", best_thr, "with macro_f1=", best_macro_f1)
    return best_thr, best_metrics

In [None]:
# ===================== GTZAN experiment =====================
gtzan_train_loader, gtzan_val_loader, gtzan_test_loader, \
gtzan_num_classes, gtzan_label_vocab = make_loaders_from_npz(
    "gtzan_mfcc128x128.npz", batch_size=32
)

print("GTZAN num_classes:", gtzan_num_classes)
print("GTZAN labels:", gtzan_label_vocab)

# ---- Train models on GTZAN ----
gtzan_hybrid = HybridRCAModel(num_classes=gtzan_num_classes,
                              hidden_dim=128, dropout=0.3)
gtzan_hybrid, _ = train_model(gtzan_hybrid,
                              gtzan_train_loader, gtzan_val_loader,
                              device, epochs=15, lr=1e-3)

gtzan_cnn = CNNAudioClassifier(num_classes=gtzan_num_classes,
                               hidden_dim=128, dropout=0.3)
gtzan_cnn, _ = train_model(gtzan_cnn,
                           gtzan_train_loader, gtzan_val_loader,
                           device, epochs=15, lr=1e-3)

gtzan_gnn = GNNOnlyClassifier(num_classes=gtzan_num_classes,
                              hidden_dim=64, dropout=0.3)
gtzan_gnn, _ = train_model(gtzan_gnn,
                           gtzan_train_loader, gtzan_val_loader,
                           device, epochs=15, lr=1e-3)

# ---- Threshold tuning on validation set (reuse your loop) ----
thresholds = np.linspace(0.05, 0.5, 10)

print("\nTuning threshold for GTZAN HYBRID...")
best_thr_h, best_metrics_h = None, None
best_macro_f1 = -1
for thr in thresholds:
    m = evaluate_full(gtzan_hybrid, gtzan_val_loader, device, threshold=thr)
    print(f"thr={thr:.2f} -> macro_f1={m['macro_f1']:.4f}, micro_f1={m['micro_f1']:.4f}")
    if m['macro_f1'] > best_macro_f1:
        best_macro_f1 = m['macro_f1']
        best_thr_h = thr
        best_metrics_h = m

print("\n[GTZAN] Best HYBRID threshold on val:", best_thr_h, best_metrics_h)

print("\nTuning threshold for GTZAN CNN...")
best_thr_c, best_metrics_c = None, None
best_macro_f1 = -1
for thr in thresholds:
    m = evaluate_full(gtzan_cnn, gtzan_val_loader, device, threshold=thr)
    print(f"thr={thr:.2f} -> macro_f1={m['macro_f1']:.4f}, micro_f1={m['micro_f1']:.4f}")
    if m['macro_f1'] > best_macro_f1:
        best_macro_f1 = m['macro_f1']
        best_thr_c = thr
        best_metrics_c = m

print("\n[GTZAN] Best CNN threshold on val:", best_thr_c, best_metrics_c)

print("\nTuning threshold for GTZAN GNN...")
best_thr_g, best_metrics_g = None, None
best_macro_f1 = -1
for thr in thresholds:
    m = evaluate_full(gtzan_gnn, gtzan_val_loader, device, threshold=thr)
    print(f"thr={thr:.2f} -> macro_f1={m['macro_f1']:.4f}, micro_f1={m['micro_f1']:.4f}")
    if m['macro_f1'] > best_macro_f1:
        best_macro_f1 = m['macro_f1']
        best_thr_g = thr
        best_metrics_g = m

print("\n[GTZAN] Best GNN threshold on val:", best_thr_g, best_metrics_g)

# ---- Final test metrics on GTZAN with tuned thresholds ----
print("\nGTZAN HYBRID (test, tuned):")
print(evaluate_full(gtzan_hybrid, gtzan_test_loader, device, threshold=best_thr_h))

print("\nGTZAN CNN (test, tuned):")
print(evaluate_full(gtzan_cnn, gtzan_test_loader, device, threshold=best_thr_c))

print("\nGTZAN GNN (test, tuned):")
print(evaluate_full(gtzan_gnn, gtzan_test_loader, device, threshold=best_thr_g))

GTZAN num_classes: 10
GTZAN labels: ['0' '1' '2' '3' '4' '5' '6' '7' '8' '9']
Epoch 01 | train_loss=0.4097 | val_loss=0.3307 | val_macroF1=0.0000
Epoch 02 | train_loss=0.3178 | val_loss=0.2945 | val_macroF1=0.0353
Epoch 03 | train_loss=0.2923 | val_loss=0.2832 | val_macroF1=0.1350
Epoch 04 | train_loss=0.2741 | val_loss=0.2674 | val_macroF1=0.1555
Epoch 05 | train_loss=0.2550 | val_loss=0.2465 | val_macroF1=0.1561
Epoch 06 | train_loss=0.2417 | val_loss=0.2495 | val_macroF1=0.1419
Epoch 07 | train_loss=0.2347 | val_loss=0.2450 | val_macroF1=0.2125
Epoch 08 | train_loss=0.2307 | val_loss=0.2307 | val_macroF1=0.1962
Epoch 09 | train_loss=0.2204 | val_loss=0.2786 | val_macroF1=0.1820
Epoch 10 | train_loss=0.2159 | val_loss=0.2237 | val_macroF1=0.2633
Epoch 11 | train_loss=0.2062 | val_loss=0.2223 | val_macroF1=0.2575
Epoch 12 | train_loss=0.2014 | val_loss=0.2962 | val_macroF1=0.1928
Epoch 13 | train_loss=0.1986 | val_loss=0.2170 | val_macroF1=0.3022
Epoch 14 | train_loss=0.1969 | val_los

In [None]:
# ===================== UrbanSound8K experiment =====================
us_train_loader, us_val_loader, us_test_loader, \
us_num_classes, us_label_vocab = make_loaders_from_npz(
    "urbansound8k_mfcc128x128.npz", batch_size=32
)

print("UrbanSound8K num_classes:", us_num_classes)
print("UrbanSound8K labels:", us_label_vocab)

# ---- Train models on UrbanSound8K ----
us_hybrid = HybridRCAModel(num_classes=us_num_classes,
                           hidden_dim=128, dropout=0.3)
us_hybrid, _ = train_model(us_hybrid,
                           us_train_loader, us_val_loader,
                           device, epochs=15, lr=1e-3)

us_cnn = CNNAudioClassifier(num_classes=us_num_classes,
                            hidden_dim=128, dropout=0.3)
us_cnn, _ = train_model(us_cnn,
                        us_train_loader, us_val_loader,
                        device, epochs=15, lr=1e-3)

us_gnn = GNNOnlyClassifier(num_classes=us_num_classes,
                           hidden_dim=64, dropout=0.3)
us_gnn, _ = train_model(us_gnn,
                        us_train_loader, us_val_loader,
                        device, epochs=15, lr=1e-3)

# ---- Threshold tuning on UrbanSound8K ----
thresholds = np.linspace(0.05, 0.5, 10)

print("\nTuning threshold for US8K HYBRID...")
best_thr_h_us, best_metrics_h_us = None, None
best_macro_f1 = -1
for thr in thresholds:
    m = evaluate_full(us_hybrid, us_val_loader, device, threshold=thr)
    print(f"thr={thr:.2f} -> macro_f1={m['macro_f1']:.4f}, micro_f1={m['micro_f1']:.4f}")
    if m['macro_f1'] > best_macro_f1:
        best_macro_f1 = m['macro_f1']
        best_thr_h_us = thr
        best_metrics_h_us = m

print("\n[US8K] Best HYBRID threshold on val:", best_thr_h_us, best_metrics_h_us)

print("\nTuning threshold for US8K CNN...")
best_thr_c_us, best_metrics_c_us = None, None
best_macro_f1 = -1
for thr in thresholds:
    m = evaluate_full(us_cnn, us_val_loader, device, threshold=thr)
    print(f"thr={thr:.2f} -> macro_f1={m['macro_f1']:.4f}, micro_f1={m['micro_f1']:.4f}")
    if m['macro_f1'] > best_macro_f1:
        best_macro_f1 = m['macro_f1']
        best_thr_c_us = thr
        best_metrics_c_us = m

print("\n[US8K] Best CNN threshold on val:", best_thr_c_us, best_metrics_c_us)

print("\nTuning threshold for US8K GNN...")
best_thr_g_us, best_metrics_g_us = None, None
best_macro_f1 = -1
for thr in thresholds:
    m = evaluate_full(us_gnn, us_val_loader, device, threshold=thr)
    print(f"thr={thr:.2f} -> macro_f1={m['macro_f1']:.4f}, micro_f1={m['micro_f1']:.4f}")
    if m['macro_f1'] > best_macro_f1:
        best_macro_f1 = m['macro_f1']
        best_thr_g_us = thr
        best_metrics_g_us = m

print("\n[US8K] Best GNN threshold on val:", best_thr_g_us, best_metrics_g_us)

# ---- Final test metrics on UrbanSound8K with tuned thresholds ----
print("\nUS8K HYBRID (test, tuned):")
print(evaluate_full(us_hybrid, us_test_loader, device, threshold=best_thr_h_us))

print("\nUS8K CNN (test, tuned):")
print(evaluate_full(us_cnn, us_test_loader, device, threshold=best_thr_c_us))

print("\nUS8K GNN (test, tuned):")
print(evaluate_full(us_gnn, us_test_loader, device, threshold=best_thr_g_us))

UrbanSound8K num_classes: 10
UrbanSound8K labels: ['air_conditioner' 'car_horn' 'children_playing' 'dog_bark' 'drilling'
 'engine_idling' 'gun_shot' 'jackhammer' 'siren' 'street_music']
Epoch 01 | train_loss=0.2903 | val_loss=0.2219 | val_macroF1=0.3254
Epoch 02 | train_loss=0.1983 | val_loss=0.1881 | val_macroF1=0.5020
Epoch 03 | train_loss=0.1660 | val_loss=0.1394 | val_macroF1=0.7275
Epoch 04 | train_loss=0.1421 | val_loss=0.1437 | val_macroF1=0.6848
Epoch 05 | train_loss=0.1259 | val_loss=0.2339 | val_macroF1=0.5241
Epoch 06 | train_loss=0.1112 | val_loss=0.1344 | val_macroF1=0.7425
Epoch 07 | train_loss=0.0995 | val_loss=0.1031 | val_macroF1=0.8012
Epoch 08 | train_loss=0.0922 | val_loss=0.1045 | val_macroF1=0.8118
Epoch 09 | train_loss=0.0817 | val_loss=0.0982 | val_macroF1=0.8193
Epoch 10 | train_loss=0.0756 | val_loss=0.1471 | val_macroF1=0.7439
Epoch 11 | train_loss=0.0706 | val_loss=0.0956 | val_macroF1=0.8218
Epoch 12 | train_loss=0.0686 | val_loss=0.0963 | val_macroF1=0.826

In [None]:
def run_experiments_on_npz(npz_path, dataset_name, batch_size=32,
                           epochs=15, lr=1e-3):
    """
    Load an MFCC .npz file (X, Y, label_vocab),
    train CNN, Hybrid, and GNN-only models,
    tune thresholds on the validation set,
    and print test metrics.
    """
    print("\n" + "="*80)
    print("Running experiments on", dataset_name)
    print("Loading:", npz_path)
    print("="*80)

    data = np.load(npz_path, allow_pickle=True)
    X = data["X"]
    Y = data["Y"]
    label_vocab = data["label_vocab"]

    num_classes = Y.shape[1]
    print("X shape:", X.shape, "Y shape:", Y.shape, "#classes:", num_classes)
    print("Labels:", label_vocab)

    # Dataset & split
    full_dataset = MFCCDataset(X, Y)
    N = len(full_dataset)
    n_train = int(0.7 * N)
    n_val   = int(0.15 * N)
    n_test  = N - n_train - n_val

    train_ds, val_ds, test_ds = random_split(full_dataset, [n_train, n_val, n_test])

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,  num_workers=2)
    val_loader   = DataLoader(val_ds,   batch_size=batch_size*2, shuffle=False, num_workers=2)
    test_loader  = DataLoader(test_ds,  batch_size=batch_size*2, shuffle=False, num_workers=2)

    # ---------------------------
    # HYBRID model
    # ---------------------------
    hybrid = HybridRCAModel(num_classes=num_classes, hidden_dim=64, dropout=0.3)
    hybrid, _ = train_model(hybrid, train_loader, val_loader, device,
                            epochs=epochs, lr=lr)

    print("\nTuning threshold for HYBRID on validation set...")
    best_thr_hybrid, hybrid_val_metrics = find_best_threshold(
        hybrid, val_loader, device
    )
    hybrid_test_metrics = evaluate_full(
        hybrid, test_loader, device, threshold=best_thr_hybrid
    )
    print("\n[HYBRID] Best threshold:", best_thr_hybrid)
    print("[HYBRID] Test metrics:", hybrid_test_metrics)

    # ---------------------------
    # CNN-only baseline
    # ---------------------------
    cnn_baseline = CNNAudioClassifier(num_classes=num_classes,
                                      hidden_dim=128, dropout=0.3)
    cnn_baseline, _ = train_model(cnn_baseline, train_loader, val_loader, device,
                                  epochs=epochs, lr=lr)

    print("\nTuning threshold for CNN baseline on validation set...")
    best_thr_cnn, cnn_val_metrics = find_best_threshold(
        cnn_baseline, val_loader, device
    )
    cnn_test_metrics = evaluate_full(
        cnn_baseline, test_loader, device, threshold=best_thr_cnn
    )
    print("\n[CNN] Best threshold:", best_thr_cnn)
    print("[CNN] Test metrics:", cnn_test_metrics)

    # ---------------------------
    # GNN-only baseline
    # ---------------------------
    gnn_only = GNNOnlyClassifier(num_classes=num_classes,
                                 hidden_dim=64, dropout=0.3)
    gnn_only, _ = train_model(gnn_only, train_loader, val_loader, device,
                              epochs=epochs, lr=lr)

    print("\nTuning threshold for GNN-only on validation set...")
    best_thr_gnn, gnn_val_metrics = find_best_threshold(
        gnn_only, val_loader, device
    )
    gnn_test_metrics = evaluate_full(
        gnn_only, test_loader, device, threshold=best_thr_gnn
    )
    print("\n[GNN] Best threshold:", best_thr_gnn)
    print("[GNN] Test metrics:", gnn_test_metrics)

    # ---------------------------
    # Ensemble (CNN + HYBRID)
    # ---------------------------
    print("\nTuning threshold for ENSEMBLE (CNN + HYBRID) on validation set...")
    best_thr_ens, ens_val_metrics = find_best_threshold_ensemble(
        cnn_baseline, hybrid, val_loader, device
    )
    ensemble_test_metrics = evaluate_ensemble_full(
        cnn_baseline, hybrid, test_loader, device, threshold=best_thr_ens
    )
    print("\n[ENSEMBLE] Best threshold:", best_thr_ens)
    print("[ENSEMBLE] Test metrics:", ensemble_test_metrics)

    return {
        "hybrid":   (best_thr_hybrid, hybrid_test_metrics),
        "cnn":      (best_thr_cnn, cnn_test_metrics),
        "gnn":      (best_thr_gnn, gnn_test_metrics),
        "ensemble": (best_thr_ens, ensemble_test_metrics),
    }

In [None]:
def find_best_threshold(model, val_loader, device,
                        thresholds=None, verbose=True):
    """
    Try different probability thresholds and pick the one that gives the
    best macro-F1 on the validation set.
    """
    if thresholds is None:
        thresholds = np.linspace(0.05, 0.5, 10)  # you can tweak this

    best_thr = None
    best_metrics = None
    best_macro_f1 = -1.0

    for thr in thresholds:
        metrics = evaluate_full(model, val_loader, device, threshold=thr)
        if verbose:
            print(
                f"thr={thr:.2f} -> "
                f"macro_f1={metrics['macro_f1']:.4f}, "
                f"micro_f1={metrics['micro_f1']:.4f}, "
                f"macro_AP={metrics['macro_AP']:.4f}"
            )

        if metrics["macro_f1"] > best_macro_f1:
            best_macro_f1 = metrics["macro_f1"]
            best_thr = thr
            best_metrics = metrics

    if verbose:
        print("\nBest threshold on val set:")
        print(
            f"thr={best_thr:.2f}, "
            f"macro_f1={best_metrics['macro_f1']:.4f}, "
            f"micro_f1={best_metrics['micro_f1']:.4f}, "
            f"macro_AP={best_metrics['macro_AP']:.4f}"
        )

    return best_thr, best_metrics

In [None]:
hybrid.to(device)
print("\nTuning threshold for HYBRID on validation set...")
best_thr_hybrid, hybrid_val_metrics = find_best_threshold(
    hybrid, val_loader, device
)
hybrid_test_metrics = evaluate_full(
    hybrid, test_loader, device, threshold=best_thr_hybrid
)
print("\nHYBRID model (test, tuned):")
print("Best threshold:", best_thr_hybrid)
print(hybrid_test_metrics)

# -----------------------------------
# CNN baseline: tuned threshold eval
# -----------------------------------
cnn_baseline.to(device)
print("\nTuning threshold for CNN baseline on validation set...")
best_thr_cnn, cnn_val_metrics = find_best_threshold(
    cnn_baseline, val_loader, device
)
cnn_test_metrics = evaluate_full(
    cnn_baseline, test_loader, device, threshold=best_thr_cnn
)
print("\nCNN baseline (test, tuned):")
print("Best threshold:", best_thr_cnn)
print(cnn_test_metrics)

# -----------------------------------
# GNN-only: tuned threshold eval
# -----------------------------------
gnn_only.to(device)
print("\nTuning threshold for GNN-only on validation set...")
best_thr_gnn, gnn_val_metrics = find_best_threshold(
    gnn_only, val_loader, device
)
gnn_test_metrics = evaluate_full(
    gnn_only, test_loader, device, threshold=best_thr_gnn
)
print("\nGNN-only baseline (test, tuned):")
print("Best threshold:", best_thr_gnn)
print(gnn_test_metrics)


Tuning threshold for HYBRID on validation set...
thr=0.05 -> macro_f1=0.2253, micro_f1=0.3998, macro_AP=0.2326
thr=0.10 -> macro_f1=0.2379, micro_f1=0.4896, macro_AP=0.2326
thr=0.15 -> macro_f1=0.2327, micro_f1=0.5460, macro_AP=0.2326
thr=0.20 -> macro_f1=0.1859, micro_f1=0.5674, macro_AP=0.2326
thr=0.25 -> macro_f1=0.1677, micro_f1=0.5762, macro_AP=0.2326
thr=0.30 -> macro_f1=0.1565, micro_f1=0.5743, macro_AP=0.2326
thr=0.35 -> macro_f1=0.1404, micro_f1=0.5716, macro_AP=0.2326
thr=0.40 -> macro_f1=0.1246, micro_f1=0.5627, macro_AP=0.2326
thr=0.45 -> macro_f1=0.1095, micro_f1=0.5565, macro_AP=0.2326
thr=0.50 -> macro_f1=0.0965, micro_f1=0.5422, macro_AP=0.2326

Best threshold on val set:
thr=0.10, macro_f1=0.2379, micro_f1=0.4896, macro_AP=0.2326

HYBRID model (test, tuned):
Best threshold: 0.1
{'loss': 0.15142579548599217, 'macro_f1': 0.24913034279387086, 'micro_f1': 0.5082841898343162, 'subset_accuracy': 0.15057680631451123, 'macro_AP': np.float64(0.24633068522937504)}

Tuning thres

In [None]:
def evaluate_ensemble_full(model_a, model_b, loader, device, threshold=0.5):
    model_a.eval()
    model_b.eval()
    all_targets = []
    all_probs = []

    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(device)
            yb = yb.to(device)

            logits_a = model_a(xb)
            logits_b = model_b(xb)

            probs_a = torch.sigmoid(logits_a)
            probs_b = torch.sigmoid(logits_b)
            probs = 0.5 * (probs_a + probs_b)

            all_targets.append(yb.cpu().numpy())
            all_probs.append(probs.cpu().numpy())

    y_true = np.concatenate(all_targets, axis=0)
    y_scores = np.concatenate(all_probs, axis=0)
    y_pred = (y_scores >= threshold).astype(np.float32)

    macro_f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)
    micro_f1 = f1_score(y_true, y_pred, average="micro", zero_division=0)
    subset_accuracy = accuracy_score(y_true, y_pred)
    macro_AP = average_precision_score(y_true, y_scores, average="macro")

    return {
        "macro_f1": macro_f1,
        "micro_f1": micro_f1,
        "subset_accuracy": subset_accuracy,
        "macro_AP": macro_AP,
    }

ensemble_test_metrics = evaluate_ensemble_full(cnn_baseline.to(device),
                                               hybrid.to(device),
                                               test_loader,
                                               device)
print("Ensemble (test):", ensemble_test_metrics)

Ensemble (test): {'macro_f1': 0.08791812727008022, 'micro_f1': 0.5509173092262696, 'subset_accuracy': 0.40619307832422585, 'macro_AP': np.float64(0.25068282209244286)}


In [None]:
# 1) AudioSet (top-K). If you created a *_topK.npz* file, use that name.
results_audioset = run_experiments_on_npz(
    npz_path="audioset_balanced_mfcc128x128_topK.npz",
    dataset_name="AudioSet (top-20 labels)",
    batch_size=32,
    epochs=15,
    lr=1e-3
)

# 2) GTZAN (10 genres)
results_gtzan = run_experiments_on_npz(
    npz_path="gtzan_mfcc128x128.npz",
    dataset_name="GTZAN (10 genres)",
    batch_size=32,
    epochs=15,
    lr=1e-3
)

# 3) UrbanSound8K (10 classes)
results_urban = run_experiments_on_npz(
    npz_path="urbansound8k_mfcc128x128.npz",
    dataset_name="UrbanSound8K (10 classes)",
    batch_size=32,
    epochs=15,
    lr=1e-3
)


Running experiments on AudioSet (top-20 labels)
Loading: audioset_balanced_mfcc128x128_topK.npz
X shape: (10975, 128, 128) Y shape: (10975, 20) #classes: 20
Labels: ['Music' 'Speech' 'Vehicle' 'Animal' 'Inside, small room'
 'Musical instrument' 'Singing' 'Domestic animals, pets' 'Guitar'
 'Plucked string instrument' 'Car' 'Dog' 'Percussion'
 'Wind instrument, woodwind instrument' 'Boat, Water vehicle' 'Water'
 'Outside, urban or manmade' 'Outside, rural or natural'
 'Brass instrument' 'Engine']
Epoch 01 | train_loss=0.2145 | val_loss=0.1672 | val_macroF1=0.0742
Epoch 02 | train_loss=0.1696 | val_loss=0.1601 | val_macroF1=0.0752
Epoch 03 | train_loss=0.1638 | val_loss=0.1594 | val_macroF1=0.0751
Epoch 04 | train_loss=0.1606 | val_loss=0.1619 | val_macroF1=0.0753
Epoch 05 | train_loss=0.1588 | val_loss=0.1562 | val_macroF1=0.0772
Epoch 06 | train_loss=0.1567 | val_loss=0.1565 | val_macroF1=0.0810
Epoch 07 | train_loss=0.1566 | val_loss=0.1561 | val_macroF1=0.0764
Epoch 08 | train_loss=0