In [1]:
from pathlib import Path
import os
import torch
from torch.utils.data import Dataset, DataLoader
import torchaudio
from torchaudio.transforms import MelSpectrogram, AmplitudeToDB
import numpy as np

# Dataset location and split configuration.
DATA_ROOT = Path("/kaggle/input/music-instrunment-sounds-for-classification")  # adjust if needed
TRAIN_DIR = DATA_ROOT / "music_dataset"   # folder scanned by build_file_list; check exact name on Kaggle UI
VAL_RATIO = 0.2                   # fraction passed as test_size to train_test_split (20% held out for validation)

def build_file_list(root_dir):
    """Collect every ``*.wav`` file under the class sub-folders of ``root_dir``.

    Each immediate sub-directory of ``root_dir`` is treated as one class.
    Classes are numbered in sorted-name order, and the files of each class are
    listed in sorted order so the result is identical from run to run
    (``Path.glob`` alone yields files in an arbitrary order).

    Args:
        root_dir: Directory containing one sub-directory per class. May be a
            ``pathlib.Path`` or a plain string path.

    Returns:
        A tuple ``(wav_paths, labels, class_to_idx)`` where ``wav_paths`` is a
        list of file paths as strings, ``labels`` is the parallel list of
        integer class indices, and ``class_to_idx`` maps class folder name to
        its index.
    """
    root_dir = Path(root_dir)

    class_names = sorted(d.name for d in root_dir.iterdir() if d.is_dir())
    class_to_idx = {name: i for i, name in enumerate(class_names)}

    wav_paths = []
    labels = []
    for cls in class_names:
        # Sort so the list (and any seeded split made from it) is reproducible.
        for wav in sorted((root_dir / cls).glob("*.wav")):
            wav_paths.append(str(wav))
            labels.append(class_to_idx[cls])
    return wav_paths, labels, class_to_idx

# Scan the training directory once; class_to_idx maps class folder name -> integer label.
all_paths, all_labels, class_to_idx = build_file_list(TRAIN_DIR)
print("Classes:", class_to_idx)
print("Total files:", len(all_paths))

# Train/validation split: hold out VAL_RATIO of the files, stratified on the
# labels, with a fixed random_state.
from sklearn.model_selection import train_test_split

train_paths, val_paths, train_labels, val_labels = train_test_split(
    all_paths, all_labels, test_size=VAL_RATIO, stratify=all_labels, random_state=42
)

print("Train size:", len(train_paths), "Val size:", len(val_paths))


Classes: {'Accordion': 0, 'Acoustic_Guitar': 1, 'Banjo': 2, 'Bass_Guitar': 3, 'Clarinet': 4, 'Cymbals': 5, 'Dobro': 6, 'Drum_set': 7, 'Electro_Guitar': 8, 'Floor_Tom': 9, 'Harmonica': 10, 'Harmonium': 11, 'Hi_Hats': 12, 'Horn': 13, 'Keyboard': 14, 'Mandolin': 15, 'Organ': 16, 'Piano': 17, 'Saxophone': 18, 'Shakers': 19, 'Tambourine': 20, 'Trombone': 21, 'Trumpet': 22, 'Ukulele': 23, 'Violin': 24, 'cowbell': 25, 'flute': 26, 'vibraphone': 27}
Total files: 42311
Train size: 33848 Val size: 8463


In [2]:
# Audio / spectrogram settings read by InstrumentDataset and by the ONNX wrapper.
SAMPLE_RATE = 16000  # audio loaded at any other rate is resampled to this
N_MELS = 128  # n_mels passed to MelSpectrogram
N_FFT = 1024  # n_fft passed to MelSpectrogram
HOP_LENGTH = 256  # hop_length passed to MelSpectrogram
SPEC_LEN = 128  # spectrograms are zero-padded or cropped to this size along dim 1
# DataLoader settings.
BATCH_SIZE = 32
NUM_WORKERS = 2

class InstrumentDataset(Dataset):
    """Map-style dataset turning audio files into normalised log-mel spectrograms.

    Each item is a ``(spectrogram, label)`` pair. The spectrogram is a float
    tensor with a leading channel dimension, ``SPEC_LEN`` columns along its
    last dimension, and per-sample zero-mean / unit-variance scaling.

    Args:
        paths: Sequence of audio file paths.
        labels: Sequence of labels, indexed in parallel with ``paths``.
    """

    def __init__(self, paths, labels):
        self.paths, self.labels = paths, labels

        # Power mel spectrogram followed by conversion to decibels.
        mel_settings = dict(
            sample_rate=SAMPLE_RATE,
            n_fft=N_FFT,
            hop_length=HOP_LENGTH,
            n_mels=N_MELS,
            center=True,
            power=2.0,
        )
        self.mel = MelSpectrogram(**mel_settings)
        self.to_db = AmplitudeToDB(stype="power")

    def __len__(self):
        """Return the number of audio files in the dataset."""
        return len(self.paths)

    def _fix_length(self, spec):
        """Zero-pad on the right, or crop, so ``spec`` has ``SPEC_LEN`` columns (dim 1)."""
        n_frames = spec.size(1)
        if n_frames == SPEC_LEN:
            return spec
        if n_frames > SPEC_LEN:
            return spec[:, :SPEC_LEN]
        missing = SPEC_LEN - n_frames
        return torch.nn.functional.pad(spec, (0, missing))

    def __getitem__(self, idx):
        """Load file ``idx`` and return ``(spectrogram, label)``."""
        wav_path, label = self.paths[idx], self.labels[idx]

        waveform, sr = torchaudio.load(wav_path)
        # Bring every file to the common sample rate.
        if sr != SAMPLE_RATE:
            waveform = torchaudio.functional.resample(waveform, sr, SAMPLE_RATE)
        # Average the channels of multi-channel audio into one.
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        log_mel = self.to_db(self.mel(waveform)).squeeze(0)
        log_mel = self._fix_length(log_mel)

        # Per-sample standardisation; the epsilon guards against a zero std.
        centre = log_mel.mean()
        scale = log_mel.std() + 1e-6
        normalised = ((log_mel - centre) / scale).unsqueeze(0)   # (1, 128, 128)

        return normalised.float(), label

# Wrap the train / validation file lists in datasets.
train_ds = InstrumentDataset(train_paths, train_labels)
val_ds   = InstrumentDataset(val_paths,   val_labels)

# Only the training loader shuffles; both use the same batch size and worker count.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=True)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=True)

# Sanity check: pull one training batch and show its shape and first labels.
x, y = next(iter(train_loader))
print("Batch shape:", x.shape)
print("Labels:", y[:8])


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Batch shape: torch.Size([32, 1, 128, 128])
Labels: tensor([15, 21, 16,  3,  0,  1,  3, 26])


In [3]:
# Number of classes, plus the inverse of class_to_idx (integer label -> class name).
NUM_CLASSES = len(class_to_idx)
idx_to_class = {v: k for k, v in class_to_idx.items()}
print("idx_to_class:", idx_to_class)

idx_to_class: {0: 'Accordion', 1: 'Acoustic_Guitar', 2: 'Banjo', 3: 'Bass_Guitar', 4: 'Clarinet', 5: 'Cymbals', 6: 'Dobro', 7: 'Drum_set', 8: 'Electro_Guitar', 9: 'Floor_Tom', 10: 'Harmonica', 11: 'Harmonium', 12: 'Hi_Hats', 13: 'Horn', 14: 'Keyboard', 15: 'Mandolin', 16: 'Organ', 17: 'Piano', 18: 'Saxophone', 19: 'Shakers', 20: 'Tambourine', 21: 'Trombone', 22: 'Trumpet', 23: 'Ukulele', 24: 'Violin', 25: 'cowbell', 26: 'flute', 27: 'vibraphone'}


In [6]:
import os
from pathlib import Path
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import torchaudio
from torchaudio.transforms import MelSpectrogram, AmplitudeToDB


In [11]:
import torch

# Choose GPU if available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Also make sure NUM_CLASSES is defined (recomputed from class_to_idx, so that
# mapping must already exist)
NUM_CLASSES = len(class_to_idx)   # if you built class_to_idx earlier
print("NUM_CLASSES:", NUM_CLASSES)
# Training hyperparameters
LR = 1e-3          # learning rate passed to the Adam optimizer
NUM_EPOCHS = 10    # number of passes of the training loop
# NOTE: train_loader / val_loader were already constructed with the earlier
# BATCH_SIZE; re-assigning the name here does not change them.
BATCH_SIZE = 32    # make sure this matches your DataLoader


Using device: cuda
NUM_CLASSES: 28


In [12]:
class SimpleCNN(nn.Module):
    """Four-stage convolutional classifier for single-channel spectrograms.

    ``features`` stacks four Conv(3x3) -> BatchNorm -> ReLU stages. The first
    three stages end in a 2x2 max-pool and the last in a global average pool.
    ``classifier`` flattens the 256 pooled channels and maps them through a
    128-unit hidden layer with dropout to ``num_classes`` logits.

    Args:
        num_classes: Size of the output logit vector.
    """

    def __init__(self, num_classes=11):
        super().__init__()

        # (in_channels, out_channels) of each convolutional stage, in order.
        stage_channels = [(1, 32), (32, 64), (64, 128), (128, 256)]
        final_stage = len(stage_channels) - 1

        feature_layers = []
        for stage, (c_in, c_out) in enumerate(stage_channels):
            feature_layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            feature_layers.append(nn.BatchNorm2d(c_out))
            feature_layers.append(nn.ReLU(inplace=True))
            if stage == final_stage:
                # Collapse whatever spatial size remains to 1x1.
                feature_layers.append(nn.AdaptiveAvgPool2d((1, 1)))
            else:
                # Halves height and width (128 -> 64 -> 32 -> 16 for a 128x128 input).
                feature_layers.append(nn.MaxPool2d(2))
        self.features = nn.Sequential(*feature_layers)

        head_layers = [
            nn.Flatten(),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits for a batch of inputs."""
        return self.classifier(self.features(x))

# Build the network with one output per class and move it to the selected device.
model = SimpleCNN(NUM_CLASSES).to(device)
# Cross-entropy on the raw logits, optimised with Adam at learning rate LR.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

print(model)


SimpleCNN(
  (features): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace=True)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU(inplace=True)
    (11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (12): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): BatchNorm2d(256, eps=1e-05, momentum=0.1, af

In [None]:
def run_epoch(loader, model, optimizer=None):
    """Run one pass over ``loader`` and return the mean loss and accuracy.

    When ``optimizer`` is given the model is put in training mode and its
    parameters are updated after every batch. Otherwise the model is put in
    evaluation mode and gradients are disabled. The loss function and target
    device are taken from the notebook-level ``criterion`` and ``device``.

    Args:
        loader: Iterable yielding ``(inputs, targets)`` batches of tensors.
        model: The network to train or evaluate.
        optimizer: Optimizer to step after each batch, or ``None`` to
            evaluate only.

    Returns:
        ``(avg_loss, accuracy)``, both averaged over all samples seen.

    Raises:
        ValueError: If ``loader`` yields no samples, since the averages would
            be undefined.
    """
    is_train = optimizer is not None
    model.train(is_train)

    total_loss, total_correct, total_count = 0.0, 0, 0

    for xb, yb in loader:
        xb = xb.to(device)
        yb = yb.to(device)

        with torch.set_grad_enabled(is_train):
            logits = model(xb)
            loss = criterion(logits, yb)

            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        # Weight the batch-mean loss by batch size so the final average is per
        # sample even when the last batch is smaller.
        total_loss += loss.item() * xb.size(0)
        preds = logits.argmax(dim=1)
        total_correct += (preds == yb).sum().item()
        total_count += xb.size(0)

    if total_count == 0:
        # Report the cause instead of failing on a bare division by zero.
        raise ValueError("run_epoch got a loader that produced no samples")

    avg_loss = total_loss / total_count
    acc = total_correct / total_count
    return avg_loss, acc

# Train for NUM_EPOCHS epochs, evaluating on the validation loader after each
# one and checkpointing whenever validation accuracy improves.
best_val_acc = 0.0
for epoch in range(1, NUM_EPOCHS + 1):
    # Passing the optimizer trains; passing None only evaluates.
    train_loss, train_acc = run_epoch(train_loader, model, optimizer)
    val_loss, val_acc = run_epoch(val_loader, model, optimizer=None)

    print(f"Epoch {epoch:02d}: "
          f"train_loss={train_loss:.4f} acc={train_acc:.4f} | "
          f"val_loss={val_loss:.4f} acc={val_acc:.4f}")

    # Overwrite best_model.pth only on a strict improvement in validation accuracy.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "best_model.pth")
        print("  → New best model saved with val_acc:", best_val_acc)

Epoch 01: train_loss=0.3258 acc=0.9045 | val_loss=0.2433 acc=0.9219
  → New best model saved with val_acc: 0.9218953089920832
Epoch 02: train_loss=0.2215 acc=0.9333 | val_loss=0.2236 acc=0.9253
  → New best model saved with val_acc: 0.9253219898381189
Epoch 03: train_loss=0.1726 acc=0.9471 | val_loss=0.1342 acc=0.9602
  → New best model saved with val_acc: 0.9601796053408956
Epoch 04: train_loss=0.1368 acc=0.9583 | val_loss=0.1524 acc=0.9526
Epoch 05: train_loss=0.1222 acc=0.9620 | val_loss=0.1055 acc=0.9680
  → New best model saved with val_acc: 0.967978258300839
Epoch 06: train_loss=0.1019 acc=0.9688 | val_loss=0.2273 acc=0.9259
Epoch 07: train_loss=0.0930 acc=0.9717 | val_loss=0.0858 acc=0.9751
  → New best model saved with val_acc: 0.9750679428098783
Epoch 08: train_loss=0.0816 acc=0.9744 | val_loss=0.1113 acc=0.9660
Epoch 09: train_loss=0.0726 acc=0.9773 | val_loss=0.1219 acc=0.9664


In [None]:
# Save final PyTorch model
# NOTE(review): this saves and exports the weights as they stand after the last
# training epoch; the best-validation checkpoint written to best_model.pth is
# not reloaded here — confirm that is intended.
torch.save(model.state_dict(), "nsynth_instrument_family_cnn.pth")

# Export to ONNX
model.eval()
# Example input with batch size 1 and a single channel, on the same device as the model.
dummy_input = torch.randn(1, 1, N_MELS, SPEC_LEN, device=device)

onnx_path = "nsynth_instrument_family_cnn.onnx"
torch.onnx.export(
    model,
    dummy_input,
    onnx_path,
    input_names=["input"],
    output_names=["logits"],
    # Mark dim 0 of the input and output as a variable-size "batch" axis.
    dynamic_axes={"input": {0: "batch"}, "logits": {0: "batch"}},
    opset_version=17,
)
print("ONNX model exported to", onnx_path)


In [None]:
import onnxruntime as ort

class NsynthOnnxWrapper:
    """Classify a single audio file with the exported ONNX model.

    The wrapper owns an onnxruntime session restricted to the CPU provider,
    plus MelSpectrogram / AmplitudeToDB transforms configured with the same
    settings as ``InstrumentDataset``, so files are preprocessed the same way
    at inference time as during training.

    Args:
        onnx_model_path: Path of the ``.onnx`` file to load.
        class_names: Optional lookup from predicted class index to a name
            (anything indexable by ``int``, e.g. a dict or list). When
            omitted, the notebook-level ``idx_to_class`` mapping is used.
    """

    def __init__(self, onnx_model_path, class_names=None):
        self.session = ort.InferenceSession(
            onnx_model_path,
            providers=["CPUExecutionProvider"]
        )
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name

        # Index -> name lookup used by predict().
        if class_names is None:
            class_names = idx_to_class
        self.class_names = class_names

        # Reuse same transforms as Dataset
        self.mel = MelSpectrogram(
            sample_rate=SAMPLE_RATE,
            n_fft=N_FFT,
            hop_length=HOP_LENGTH,
            n_mels=N_MELS,
            center=True,
            power=2.0,
        )
        self.to_db = AmplitudeToDB(stype="power")

    def preprocess_wav(self, wav_path):
        """Load ``wav_path`` and return the model input as a float32 NumPy array.

        The steps mirror ``InstrumentDataset.__getitem__``: resample to
        ``SAMPLE_RATE``, average channels to mono, compute the log-mel
        spectrogram, pad or crop it to ``SPEC_LEN`` columns, standardise it,
        then add leading batch and channel dimensions.
        """
        waveform, sr = torchaudio.load(wav_path)
        if sr != SAMPLE_RATE:
            waveform = torchaudio.functional.resample(
                waveform, orig_freq=sr, new_freq=SAMPLE_RATE
            )
        if waveform.size(0) > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        spec = self.mel(waveform)
        spec = self.to_db(spec).squeeze(0)
        # same length handling as Dataset
        if spec.size(1) < SPEC_LEN:
            pad = SPEC_LEN - spec.size(1)
            spec = torch.nn.functional.pad(spec, (0, pad))
        elif spec.size(1) > SPEC_LEN:
            spec = spec[:, :SPEC_LEN]

        mean = spec.mean()
        std = spec.std() + 1e-6
        spec = (spec - mean) / std
        spec = spec.unsqueeze(0).unsqueeze(0)  # (1,1,128,128)

        return spec.numpy().astype("float32")

    def predict(self, wav_path):
        """Classify one audio file.

        Returns:
            ``(pred_idx, pred_name, probs)``: the arg-max class index, its
            name from the wrapper's class-name lookup, and the softmax
            probabilities for every class as a 1-D NumPy array.
        """
        x = self.preprocess_wav(wav_path)
        logits = self.session.run([self.output_name], {self.input_name: x})[0]
        probs = torch.softmax(torch.from_numpy(logits), dim=1).numpy()[0]
        pred_idx = int(np.argmax(probs))
        pred_name = self.class_names[pred_idx]
        return pred_idx, pred_name, probs

# Example usage (outside Streamlit):
# NOTE(review): "/path/to/some.wav" is a placeholder; replace it with a real
# audio file before running this cell.
wrapper = NsynthOnnxWrapper("nsynth_instrument_family_cnn.onnx")
idx, name, probs = wrapper.predict("/path/to/some.wav")
print(idx, name)
