Evaluate Wav2Vec2 Model

This notebook evaluates the Wav2Vec2 model trained on the RAVDESS dataset and plots its confusion matrix on the validation set.



In [4]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
from torch.utils.data import DataLoader
import torch
#from emotion_classification import RAVDESSDataset

In [5]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
from torch.utils.data import DataLoader
import torch
import torchaudio
import os
from omegaconf import OmegaConf
#
# Load config: paths and audio settings all come from the shared YAML file.
cfg = OmegaConf.load('../configs/config.yaml')
MODEL_PATH = cfg.model.save_path
VAL_DIR = cfg.data.val_dir
SAMPLE_RATE = cfg.data.sample_rate
MAX_LENGTH = cfg.data.max_length
# Display names indexed by class id. RAVDESSDataset derives the class id as
# the file-name emotion code minus one, and the RAVDESS codes are
# 01=neutral, 02=calm, 03=happy, 04=sad, 05=angry, 06=fearful, 07=disgust,
# 08=surprised -- so 'disgust' (index 6) must come before 'surprise' (index 7).
EMOTIONS = ['neutral', 'calm', 'happy', 'sad', 'angry', 'fearful', 'disgust', 'surprise']
#
# RAVDESS Dataset class#
class RAVDESSDataset(torch.utils.data.Dataset):
    """Fixed-length RAVDESS clips prepared for a Wav2Vec2 classifier.

    Every ``.wav`` file found directly inside ``data_dir`` becomes one example.
    Its label is the third dash-separated field of the file name minus one,
    e.g. ``03-01-05-01-02-01-12.wav`` gets label ``4``.

    Args:
        data_dir: Directory scanned (non-recursively) for ``.wav`` files.
        processor: Callable applied to the mono waveform; its result must
            expose ``input_values`` and ``attention_mask``.
        max_length: Clip duration in seconds. Longer clips are truncated,
            shorter ones are zero-padded on the right.
        sample_rate: Target sampling rate in Hz; clips recorded at another
            rate are resampled to it.

    Raises:
        IndexError: A ``.wav`` file name has fewer than three dash-separated
            fields.
        ValueError: The third field of a ``.wav`` file name is not an integer.
    """

    def __init__(self, data_dir, processor, max_length=5.0, sample_rate=16000):
        self.data_dir = data_dir
        self.processor = processor
        self.max_length = max_length
        self.sample_rate = sample_rate
        self.files = []
        self.labels = []

        # os.listdir returns entries in arbitrary order; sorting keeps the
        # example order reproducible across runs and machines.
        for file in sorted(os.listdir(data_dir)):
            if file.endswith('.wav'):
                self.files.append(os.path.join(data_dir, file))
                # File-name emotion codes start at 1; class ids start at 0.
                emotion_id = int(file.split('-')[2]) - 1
                self.labels.append(emotion_id)

    def __len__(self):
        """Return the number of ``.wav`` files discovered in ``data_dir``."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load, normalise and featurise the clip at position ``idx``.

        Returns:
            dict with ``input_values`` and ``attention_mask`` (processor
            outputs with the leading batch dimension removed) and ``labels``
            (the integer class id).
        """
        file_path = self.files[idx]
        label = self.labels[idx]

        waveform, sr = torchaudio.load(file_path)
        # Downmix multi-channel recordings to mono. Without this a stereo clip
        # reaches the processor as a 2-D array, yielding an item whose shape
        # differs from the mono items and cannot be batched with them.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        if sr != self.sample_rate:
            waveform = torchaudio.transforms.Resample(sr, self.sample_rate)(waveform)

        # Force every clip to exactly max_length seconds so items stack cleanly.
        max_samples = int(self.max_length * self.sample_rate)
        if waveform.shape[1] > max_samples:
            waveform = waveform[:, :max_samples]
        else:
            padding = max_samples - waveform.shape[1]
            waveform = torch.nn.functional.pad(waveform, (0, padding))

        # squeeze(0) drops only the channel / batch dimension, never the time axis.
        # NOTE(review): assumes the processor is configured to return an
        # attention mask -- confirm against the saved preprocessor config.
        inputs = self.processor(
            waveform.squeeze(0).numpy(),
            sampling_rate=self.sample_rate,
            return_tensors='pt',
            padding=True,
        )
        return {
            'input_values': inputs.input_values.squeeze(0),
            'attention_mask': inputs.attention_mask.squeeze(0),
            'labels': label
        }
#
# Load model and processor
# Both loaders read from MODEL_PATH. The recorded run of this cell failed with
# an OSError because that directory had no preprocessor_config.json, so the
# processor files must be present there alongside the model weights.
processor = Wav2Vec2Processor.from_pretrained(MODEL_PATH)
model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_PATH)
model.eval()

# Load validation dataset
val_dataset = RAVDESSDataset(VAL_DIR, processor, MAX_LENGTH, SAMPLE_RATE)
val_loader = DataLoader(val_dataset, batch_size=8)

# Get predictions
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
all_preds, all_labels = [], []
with torch.no_grad():
    for batch in val_loader:
        input_values = batch['input_values'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # Labels are only collected for the metric, so they stay on the CPU.
        labels = batch['labels']
        outputs = model(input_values, attention_mask=attention_mask)
        all_preds.extend(outputs.logits.argmax(dim=-1).cpu().numpy())
        all_labels.extend(labels.numpy())

# Plot confusion matrix
# Pin the class ids explicitly: without `labels=` the matrix only covers ids
# that actually occur, and would no longer line up with the EMOTIONS names
# whenever a class is missing from the validation split.
class_ids = list(range(len(EMOTIONS)))
cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=EMOTIONS)
disp.plot(cmap=plt.cm.Blues, xticks_rotation=45)
plt.title('Confusion Matrix - Wav2Vec2')
plt.show()

OSError: Can't load feature extractor for '../models/wav2vec2-emotion'. If you were trying to load it from 'https://huggingface.co/models', make sure you don't have a local directory with the same name. Otherwise, make sure '../models/wav2vec2-emotion' is the correct path to a directory containing a preprocessor_config.json file