In [1]:
import labeler
!pip install torch torchaudio torchvision librosa matplotlib soundfile

Collecting torch
  Obtaining dependency information for torch from https://files.pythonhosted.org/packages/81/b4/605ae4173aa37fb5aa14605d100ff31f4f5d49f617928c9f486bb3aaec08/torch-2.6.0-cp312-none-macosx_11_0_arm64.whl.metadata
  Using cached torch-2.6.0-cp312-none-macosx_11_0_arm64.whl.metadata (28 kB)
Collecting torchaudio
  Obtaining dependency information for torchaudio from https://files.pythonhosted.org/packages/ac/4a/d71b932bda4171970bdf4997541b5c778daa0e2967ed5009d207fca86ded/torchaudio-2.6.0-cp312-cp312-macosx_11_0_arm64.whl.metadata
  Using cached torchaudio-2.6.0-cp312-cp312-macosx_11_0_arm64.whl.metadata (6.6 kB)
Collecting torchvision
  Obtaining dependency information for torchvision from https://files.pythonhosted.org/packages/6e/1b/28f527b22d5e8800184d0bc847f801ae92c7573a8c15979d92b7091c0751/torchvision-0.21.0-cp312-cp312-macosx_11_0_arm64.whl.metadata
  Using cached torchvision-0.21.0-cp312-cp312-macosx_11_0_arm64.whl.metadata (6.1 kB)
Collecting filelock (fro

In [2]:
import os
import glob
import librosa
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt

# Reproducibility: seed torch's global generator (which random_split and
# DataLoader shuffling draw from by default) and numpy's legacy global RNG.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
class AudioDataset(Dataset):
    """
    PyTorch Dataset that turns audio clips on disk into log-Mel spectrograms.

    Behaviour:
      - Collects files in ``data_dir`` whose name ends with one of ``extensions``
        (``.npy`` only by default). The scan is non-recursive unless
        ``recursive=True``. The file list is sorted so indices are stable.
      - Derives a binary label from the filename: if the text before the first
        underscore contains the substring "accept" the label is 0, otherwise 1.
      - Loads the waveform (``librosa.load`` for ``.wav``, ``np.load`` otherwise),
        computes a power Mel-spectrogram, converts it to dB relative to the
        clip's own maximum, and returns it with a leading channel axis.

    Args:
        data_dir: Directory to scan.
        sample_rate: Rate that WAV files are resampled to; also assumed for
            .npy files, which carry no rate information.
        n_mels: Number of Mel bands.
        transform: Optional callable applied to the dB spectrogram before it
            is converted to a tensor.
        extensions: Filename suffix(es) to collect (case-sensitive). A single
            string is accepted as well as a sequence of strings.
        recursive: If True, also scan sub-directories of ``data_dir``.
    """

    def __init__(self, data_dir, sample_rate=44100, n_mels=64, transform=None,
                 extensions=(".npy",), recursive=False):
        super().__init__()
        self.data_dir = data_dir
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.transform = transform

        # A bare string would otherwise be iterated character by character.
        if isinstance(extensions, str):
            extensions = (extensions,)

        # With recursive=True, "**" matches zero or more directories, so files
        # directly inside data_dir are still included.
        search_dir = os.path.join(data_dir, "**") if recursive else data_dir
        found = set()
        for ext in extensions:
            found.update(glob.glob(os.path.join(search_dir, "*" + ext), recursive=recursive))
        # Sorted so dataset indices (and therefore random_split results) are stable.
        self.audio_files = sorted(found)

    def __len__(self):
        return len(self.audio_files)

    @staticmethod
    def _label_from_filename(filename):
        """Return 0 if the first '_'-separated token contains "accept", else 1.

        This is a substring test, e.g. "accept_coin_1.npy" -> 0 and
        "reject_coin_1.npy" -> 1. Any name without "accept" in its first
        token maps to 1.
        """
        label_str = filename.split("_")[0]
        return 0 if "accept" in label_str else 1

    def _load_waveform(self, file_path):
        """Load audio samples and their sample rate from a .wav or .npy file."""
        if file_path.endswith(".wav"):
            y, sr = librosa.load(file_path, sr=self.sample_rate, mono=True)
        else:
            y = np.load(file_path)
            sr = self.sample_rate  # .npy has no rate metadata; assume it matches
            # librosa's spectrogram code requires floating-point audio, so cast
            # integer dumps (e.g. int16 PCM). Float arrays are left untouched.
            if not np.issubdtype(y.dtype, np.floating):
                y = y.astype(np.float32)
        return y, sr

    def __getitem__(self, idx):
        """Return ``(mel_tensor, label)`` for the idx-th file.

        For 1-D audio, ``mel_tensor`` has shape (1, n_mels, time). The number of
        time frames depends on clip length, so batching clips of different
        lengths needs a padding collate_fn.
        """
        file_path = self.audio_files[idx]
        label = self._label_from_filename(os.path.basename(file_path))
        y, sr = self._load_waveform(file_path)

        # Power Mel-spectrogram, then dB relative to this clip's own peak.
        mel_spec = librosa.feature.melspectrogram(
            y=y, sr=sr, n_mels=self.n_mels, fmax=sr // 2
        )
        mel_db = librosa.power_to_db(mel_spec, ref=np.max)

        # Optional: apply transform / augmentation
        if self.transform:
            mel_db = self.transform(mel_db)

        # as_tensor accepts either a numpy array or a tensor returned by
        # `transform`, without the warning torch.tensor(tensor) emits.
        mel_tensor = torch.as_tensor(mel_db, dtype=torch.float)

        # Add a channel axis: (n_mels, time) -> (1, n_mels, time) for Conv2d input.
        return mel_tensor.unsqueeze(0), label

In [None]:
# Tag vocabulary for multi-hot labels: the accept/reject decision tags first,
# then the descriptive tags exposed by the project-local `labeler` module
# (coin, ring, foil, silver, gold, ringpull, ...).
# dict.fromkeys drops repeats while keeping first-seen order, so every tag owns
# exactly one index even if labeler.AVAILABLE_TAGS repeats a tag.
# NOTE(review): if AVAILABLE_TAGS is a set, its iteration order (and therefore
# the tag indices) may differ between runs. Confirm it is an ordered sequence.
ALL_TAGS = list(dict.fromkeys(["accept", "reject", *labeler.AVAILABLE_TAGS]))
tag2idx = {tag: idx for idx, tag in enumerate(ALL_TAGS)}


def parse_tags_from_filename(filename, tags=None):
    """
    Build a multi-hot label vector from the underscore-separated tokens of a filename.

    Example: "accept_coin_silver_123abc.wav" is split into
    ["accept", "coin", "silver", "123abc"]. Each token that is a known tag sets
    that tag's slot to 1. Unknown tokens (like "123abc") are ignored.

    Args:
        filename: File name. The extension is stripped before splitting, and
            matching is exact and case-sensitive. Pass a bare name rather than
            a path, otherwise the directory becomes part of the first token.
        tags: Optional tag vocabulary (any iterable of strings). Defaults to the
            module-level ``ALL_TAGS`` with its ``tag2idx`` index.

    Returns:
        list[int]: One 0/1 entry per tag, in vocabulary order.
    """
    if tags is None:
        vocab_size, index = len(ALL_TAGS), tag2idx
    else:
        tags = list(tags)
        vocab_size = len(tags)
        index = {tag: i for i, tag in enumerate(tags)}

    base, _ = os.path.splitext(filename)  # e.g. "accept_coin_silver_123abc"

    label_vec = [0] * vocab_size
    for part in base.split("_"):
        if part in index:
            label_vec[index[part]] = 1
    return label_vec