In [4]:
import torch
import torchvggish
import librosa
import os
import numpy as np
from panns_inference import AudioTagging
from pydub import AudioSegment

In [5]:
# Load both models once; the helper functions below reuse these module-level instances.
# nn.Module.eval() returns the module itself, so VGGish is put into inference mode inline.
vggish_model = torchvggish.vggish().eval()
# PANNs AudioSet tagger; checkpoint_path=None lets panns_inference use its default pretrained checkpoint.
audio_tagging = AudioTagging(checkpoint_path=None)

Checkpoint path: C:\Users\krzyzehj/panns_data/Cnn14_mAP=0.431.pth


  checkpoint = torch.load(checkpoint_path, map_location=self.device)


Using CPU.


In [6]:
# Function to extract VGGish embeddings

# VGGish front-end parameters (mirroring vggish_params.py of the reference
# TensorFlow VGGish implementation).
_VGGISH_SAMPLE_RATE = 16000
_VGGISH_WIN_LENGTH = 400       # 25 ms analysis window at 16 kHz
_VGGISH_HOP_LENGTH = 160       # 10 ms hop at 16 kHz
_VGGISH_N_FFT = 512            # smallest power of two >= the window length
_VGGISH_N_MELS = 64
_VGGISH_FMIN = 125.0
_VGGISH_FMAX = 7500.0
_VGGISH_LOG_OFFSET = 0.01      # log(mel + offset) avoids log(0)
_VGGISH_EXAMPLE_FRAMES = 96    # one VGGish patch = 0.96 s of 10 ms frames


def _vggish_log_mel_examples(y, sr):
    """Convert a 16 kHz mono waveform into a VGGish input batch.

    Returns a float32 tensor of shape (num_examples, 1, 96, 64). Each entry is a
    non-overlapping 0.96 s patch, with time on the 96 axis and mel bands on the
    64 axis. Trailing audio that does not fill a whole patch is dropped. Clips
    shorter than one patch are zero-padded so at least one patch is produced.
    This is a librosa approximation of the reference VGGish front-end, not a
    bit-exact port.
    """
    # With center=False librosa yields 1 + (len - n_fft) // hop frames; make sure
    # that is at least one full patch.
    min_samples = _VGGISH_N_FFT + (_VGGISH_EXAMPLE_FRAMES - 1) * _VGGISH_HOP_LENGTH
    if len(y) < min_samples:
        y = np.pad(y, (0, min_samples - len(y)))

    # Magnitude STFT (power=1), HTK mel scale and un-normalised triangular
    # filters, as in the reference front-end.
    mel = librosa.feature.melspectrogram(
        y=y, sr=sr,
        n_fft=_VGGISH_N_FFT, win_length=_VGGISH_WIN_LENGTH, hop_length=_VGGISH_HOP_LENGTH,
        center=False, power=1.0,
        n_mels=_VGGISH_N_MELS, fmin=_VGGISH_FMIN, fmax=_VGGISH_FMAX,
        htk=True, norm=None,
    )
    log_mel = np.log(mel + _VGGISH_LOG_OFFSET).T  # (frames, 64): time first

    num_examples = log_mel.shape[0] // _VGGISH_EXAMPLE_FRAMES
    examples = log_mel[: num_examples * _VGGISH_EXAMPLE_FRAMES].reshape(
        num_examples, _VGGISH_EXAMPLE_FRAMES, _VGGISH_N_MELS
    )
    # Add the single channel axis -> (num_examples, 1, 96, 64)
    return torch.from_numpy(np.ascontiguousarray(examples, dtype=np.float32)).unsqueeze(1)


def extract_vggish_embeddings(wav_path):
    """Return one VGGish embedding vector for a whole audio file.

    The audio is loaded as mono at 16 kHz and split into 0.96 s VGGish patches.
    Each patch is embedded by ``vggish_model``. The per-patch embeddings are
    averaged, so the result is a single vector regardless of clip length.

    Args:
        wav_path: Path to an audio file readable by librosa.

    Returns:
        1-D numpy array with the clip-level embedding, or None if loading or
        inference fails (the error is printed).
    """
    try:
        # Load as mono and resample to the 16 kHz rate VGGish works at
        y, sr = librosa.load(wav_path, sr=_VGGISH_SAMPLE_RATE)

        # NOTE(review): harritaylor/torchvggish's VGGish has a `preprocess` flag
        # (default True). When set, it expects a raw np.ndarray waveform plus the
        # sample rate and raises on a tensor. Confirm against the installed
        # torchvggish version.
        if getattr(vggish_model, "preprocess", False):
            model_args = (y, sr)
        else:
            model_args = (_vggish_log_mel_examples(y, sr),)

        with torch.no_grad():
            embeddings = vggish_model(*model_args)

        embeddings = embeddings.detach().cpu().numpy()
        # One row per 0.96 s patch -> average to a single clip-level vector so
        # the return shape does not depend on clip length.
        if embeddings.ndim > 1:
            embeddings = embeddings.reshape(-1, embeddings.shape[-1]).mean(axis=0)
        return embeddings

    except Exception as e:
        print(f"Error processing {wav_path}: {e}")
        return None

In [23]:
# Function to classify instruments using raw audio input

# AudioSet class names (as they appear in ``audio_tagging.labels``) that count as
# instruments, mapped to the name reported to callers. AudioSet names the violin
# class "Violin, fiddle", so a bare "Violin" entry could never match.
_INSTRUMENT_LABELS = {
    "Guitar": "Guitar",
    "Bass guitar": "Bass guitar",
    "Violin, fiddle": "Violin",
    "Cello": "Cello",
    "Flute": "Flute",
    "Clarinet": "Clarinet",
    "Saxophone": "Saxophone",
    "Trumpet": "Trumpet",
    "Piano": "Piano",
    "Drum": "Drum",
    "Cymbal": "Cymbal",
    "Organ": "Organ",
}


def classify_instruments(wav_path, threshold=0.1):
    """Return the instruments the PANNs tagger detects in an audio file.

    Args:
        wav_path: Path to an audio file readable by librosa.
        threshold: Minimum clip-wise score for a class to count as present.
            Defaults to 0.1, the value previously hard-coded.

    Returns:
        List of instrument names (values of ``_INSTRUMENT_LABELS``), ordered by
        the tagger's class index. The list is empty if nothing passes the
        threshold or if processing fails (the error is printed).
    """
    try:
        # Load as mono and resample to 32 kHz, the rate the PANNs checkpoint expects
        y, sr = librosa.load(wav_path, sr=32000)
        audio_tensor = torch.tensor(y).float().unsqueeze(0)  # Shape: [1, audio_length]

        with torch.no_grad():
            result = audio_tagging.inference(audio_tensor)

        # result[0] holds the clip-wise scores, indexed like audio_tagging.labels
        clipwise_output = np.asarray(result[0]).ravel()
        labels = audio_tagging.labels

        instrument_predictions = []
        for idx in np.flatnonzero(clipwise_output > threshold):
            name = _INSTRUMENT_LABELS.get(labels[idx])
            if name is not None:
                instrument_predictions.append(name)
        return instrument_predictions

    except Exception as e:
        print(f"Error processing {wav_path}: {e}")
        return []

In [30]:
# Split a .wav file into fixed-length chunks (2 s by default) and classify the instruments in each
def split_wav(wav_path, chunk_length_ms=2000, output_folder=r"Data/wav/wav_split/", silence_thresh_dbfs=-60):
    """Return the distinct instruments detected across fixed-length chunks of a .wav file.

    The file is cut into consecutive ``chunk_length_ms`` pieces; the last piece
    may be shorter. Pieces whose loudness is below ``silence_thresh_dbfs`` are
    skipped. Every other piece is written to ``output_folder`` as
    ``chunk_<n>.wav`` and passed to ``classify_instruments``. Files from earlier
    calls with the same names are overwritten.

    Args:
        wav_path: Path to the source .wav file.
        chunk_length_ms: Chunk length in milliseconds (must be positive).
        output_folder: Directory for the temporary chunk files (created if missing).
        silence_thresh_dbfs: Chunks quieter than this (dBFS) are not classified.

    Returns:
        List of instrument names in order of first detection, without duplicates.

    Raises:
        ValueError: If ``chunk_length_ms`` is not positive.
    """
    if chunk_length_ms <= 0:
        raise ValueError("chunk_length_ms must be positive")

    os.makedirs(output_folder, exist_ok=True)
    audio = AudioSegment.from_wav(wav_path)

    # dict keys keep insertion order, giving an order-preserving de-duplication
    detected = {}
    for chunk_number, start_ms in enumerate(range(0, len(audio), chunk_length_ms), start=1):
        chunk = audio[start_ms:start_ms + chunk_length_ms]

        # dBFS is -inf for digital silence, so fully silent chunks are skipped too
        if chunk.dBFS < silence_thresh_dbfs:
            continue

        chunk_filename = os.path.join(output_folder, f"chunk_{chunk_number}.wav")
        # Write through our own handle so the file is closed before it is re-read
        with open(chunk_filename, "wb") as chunk_file:
            chunk.export(chunk_file, format="wav")

        for instrument in classify_instruments(chunk_filename):
            detected.setdefault(instrument, None)

    return list(detected)

In [None]:
# Example usage: report the instruments detected in every .wav file of one genre folder
TEST_PATH = r"Data/wav/genres_original/classical/"
# sorted() makes the output order reproducible (os.listdir order is arbitrary).
# The filter skips sub-folders and non-.wav files, which AudioSegment.from_wav
# inside split_wav would raise on, aborting the whole loop.
wav_files = sorted(
    name for name in os.listdir(TEST_PATH)
    if name.lower().endswith(".wav") and os.path.isfile(os.path.join(TEST_PATH, name))
)
for file in wav_files:
    wav_path = os.path.join(TEST_PATH, file)
    instruments = split_wav(wav_path)
    print(f"Instruments in file {wav_path} are... \n {instruments}")

Instruments in file Data/wav/genres_original/classical/008RKiNmjW5Lb6Ocumq6MA.wav are... 
 ['Cello', 'Piano']
Instruments in file Data/wav/genres_original/classical/04eShjKTWijeJJqGnhxpYK.wav are... 
 ['Cello']
Instruments in file Data/wav/genres_original/classical/05rNWKxli5goHcA4e77sGC.wav are... 
 ['Piano']
Instruments in file Data/wav/genres_original/classical/06am46cX3Z6YlSsg0TyVHA.wav are... 
 ['Cello']
Instruments in file Data/wav/genres_original/classical/07xafomqQcYmFJbr4jpfHa.wav are... 
 ['Cello']
Instruments in file Data/wav/genres_original/classical/086sjLPEqdKBgTxbTeCLCv.wav are... 
 ['Piano', 'Guitar']
Instruments in file Data/wav/genres_original/classical/0Gef573AJfARbMuQSoCy2r.wav are... 
 ['Piano', 'Cello']
Instruments in file Data/wav/genres_original/classical/0Gh45IbIKOG9IucFfrZqLT.wav are... 
 ['Cello']
Instruments in file Data/wav/genres_original/classical/0gsBQy8Q0eWlR67xDmoWFw.wav are... 
 ['Piano']
Instruments in file Data/wav/genres_original/classical/0iEIVX3D