In [7]:
import torch
import librosa
import numpy as np
from src.model import HybridCNNGRUWithAttention  # Assuming model class is defined in model.py
import os
from skimage.transform import resize
from src.dataset import normalize_peak

In [9]:
def preprocess_audio_file(audio_path, resized_shape=(150, 150)):
    spectrograms = []
        
    if audio_path.lower().endswith(('.wav', '.mp3')):
        audio, rate = librosa.load(audio_path, sr=None)
        audio = normalize_peak(audio)  # Apply peak normalization
        
        duration_of_chunk = 4  # seconds
        overlap = 2  # seconds
        samples_per_chunk = duration_of_chunk * rate
        samples_overlap = overlap * rate
        total_chunks = int(np.ceil((len(audio) - samples_per_chunk) / (samples_per_chunk - samples_overlap))) + 1
        
        for chunk_number in range(total_chunks):
            start_sample = chunk_number * (samples_per_chunk - samples_overlap)
            end_sample = start_sample + samples_per_chunk
            audio_chunk = audio[start_sample:end_sample]
            mel_spect = librosa.feature.melspectrogram(y=audio_chunk, sr=rate)
            resized_mel_spect = resize(np.expand_dims(mel_spect, axis=-1), resized_shape)
            spectrograms.append(resized_mel_spect)
    
    return np.array(spectrograms)

In [14]:
# Define class names
class_names = ['Blues', 'Classical', 'Country', 'Disco', 'Hip-hop', 'Jazz', 'Metal', 'Pop', 'Reggae', 'Rock']

# Load and preprocess audio
audio_path = 'data/test_files/pop4_billiejean.mp3'

spectrograms = preprocess_audio_file(audio_path)

# Convert to tensor and ensure correct dimensions
X = torch.FloatTensor(spectrograms).permute(0, 3, 1, 2)

# Load trained model
model_path = 'output/saved_models/best_gru_with_attention_model_1.pth'
model = HybridCNNGRUWithAttention()  # Initialize model with same architecture
model.load_state_dict(torch.load(model_path))
model.eval()

# Make prediction
with torch.no_grad():
    output = model(X)
    predictions = torch.sigmoid(output)
    
# Print prediction probabilities for each class
for i, prob in enumerate(predictions[0]):
    print(f"{class_names[i]}: {prob.item():.4f}")

Blues: 0.0262
Classical: 0.0124
Country: 0.0384
Disco: 0.0086
Hip-hop: 0.9740
Jazz: 0.0290
Metal: 0.7356
Pop: 0.9955
Reggae: 0.0133
Rock: 0.0233
