In [2]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, accuracy_score, precision_recall_fscore_support, ConfusionMatrixDisplay
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
import torch
import torchaudio
import librosa
from torch.utils.data import DataLoader, Dataset
import os
from omegaconf import OmegaConf
import pandas as pd

# Project configuration: checkpoint locations and audio preprocessing settings.
cfg = OmegaConf.load('../configs/config.yaml')

# Where the two trained models live.
WAV2VEC2_PATH = cfg.model.save_path
CNN_PATH = os.path.join(cfg.model.cnn_save_path, 'cnn_model.pth')

# Validation split location and the audio settings both datasets use.
VAL_DIR, SAMPLE_RATE, MAX_LENGTH, N_MELS = (
    cfg.data.val_dir,
    cfg.data.sample_rate,
    cfg.data.max_length,
    cfg.data.n_mels,
)

# Class names in class-id order; the datasets derive ids as
# (emotion code from the filename) - 1.
EMOTIONS = 'neutral calm happy sad angry fearful disgust surprise'.split()

# CNN Model Definition
class CNN(torch.nn.Module):
    """Two-block convolutional classifier over single-channel mel spectrograms.

    Inputs are expected as ``(batch, 1, n_mels, n_frames)``. Each conv layer
    keeps the spatial size (3x3 kernel, stride 1, padding 1). Each 2x2
    max-pool halves it with flooring. After two blocks the feature map is
    therefore ``32 x (n_mels // 4) x (n_frames // 4)``, which sizes ``fc1``.

    Args:
        num_classes: Number of output logits.
        n_mels: Height of the input spectrogram (mel bins).
        n_frames: Width of the input spectrogram (time frames). The default
            of 157 equals ``1 + 80000 // 512``. That is 5 s of 16 kHz audio
            with librosa's default hop length of 512. Together with
            ``n_mels=128`` it reproduces the previous hard-coded
            ``32 * 32 * 39``, so existing checkpoints still load.
    """

    def __init__(self, num_classes=8, n_mels=128, n_frames=157):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = torch.nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.pool = torch.nn.MaxPool2d(2, 2)
        # Two floor-halvings: floor(floor(x / 2) / 2) == x // 4.
        self.fc1 = torch.nn.Linear(32 * (n_mels // 4) * (n_frames // 4), 128)
        self.dropout = torch.nn.Dropout(0.5)
        self.fc2 = torch.nn.Linear(128, num_classes)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Return unnormalized class logits of shape ``(batch, num_classes)``."""
        x = self.relu(self.conv1(x))
        x = self.pool(x)
        x = self.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(x.size(0), -1)  # flatten everything but the batch dim
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# Dataset for Wav2Vec2
class RAVDESSDatasetWav2Vec2(Dataset):
    """RAVDESS ``.wav`` clips prepared as Wav2Vec2 processor inputs.

    Class ids come from the third dash-separated field of each filename,
    converted from the 1-based emotion code to a 0-based id.

    Args:
        data_dir: Directory scanned (non-recursively) for ``.wav`` files.
        processor: Callable used to turn the raw waveform into model inputs.
        max_length: Clip length in seconds; longer audio is truncated and
            shorter audio is zero-padded at the end.
        sample_rate: Target sample rate; audio at other rates is resampled.
    """

    def __init__(self, data_dir, processor, max_length=5.0, sample_rate=16000):
        self.data_dir = data_dir
        self.processor = processor
        self.max_length = max_length
        self.sample_rate = sample_rate
        self.files = []
        self.labels = []

        # os.listdir order is arbitrary; sorting makes item order (and thus
        # prediction order) deterministic and identical for every dataset
        # built from the same directory.
        for file in sorted(os.listdir(data_dir)):
            if file.endswith('.wav'):
                self.files.append(os.path.join(data_dir, file))
                emotion_id = int(file.split('-')[2]) - 1
                self.labels.append(emotion_id)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return a dict with ``input_values``, ``attention_mask`` and ``labels``."""
        file_path = self.files[idx]
        label = self.labels[idx]

        waveform, sr = torchaudio.load(file_path)  # (channels, samples)
        if sr != self.sample_rate:
            waveform = torchaudio.transforms.Resample(sr, self.sample_rate)(waveform)
        # Average channels so the signal handed to the processor is 1-D.
        # Otherwise squeeze() would leave a (channels, samples) array. This
        # matches the mono behaviour of librosa.load used for the CNN data.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        max_samples = int(self.max_length * self.sample_rate)
        if waveform.shape[1] > max_samples:
            waveform = waveform[:, :max_samples]
        else:
            padding = max_samples - waveform.shape[1]
            waveform = torch.nn.functional.pad(waveform, (0, padding))

        inputs = self.processor(waveform.squeeze().numpy(), sampling_rate=self.sample_rate, return_tensors='pt', padding=True)
        input_values = inputs.input_values.squeeze()
        # Feature extractors configured with return_attention_mask=False omit
        # the mask. For a single, already fixed-length sample the processor's
        # mask would be all ones, so fall back to exactly that.
        attention_mask = inputs.get('attention_mask')
        if attention_mask is None:
            attention_mask = torch.ones_like(input_values, dtype=torch.long)
        else:
            attention_mask = attention_mask.squeeze()
        return {
            'input_values': input_values,
            'attention_mask': attention_mask,
            'labels': label
        }

# Dataset for CNN
class RAVDESSMelDataset(Dataset):
    """RAVDESS ``.wav`` clips converted to log-mel spectrograms for the CNN.

    Class ids come from the third dash-separated field of each filename,
    converted from the 1-based emotion code to a 0-based id.

    Args:
        data_dir: Directory scanned (non-recursively) for ``.wav`` files.
        sample_rate: Rate passed to ``librosa.load`` (audio is resampled to it).
        n_mels: Number of mel bins in the spectrogram.
        max_length: Clip length in seconds; longer audio is truncated and
            shorter audio is zero-padded at the end.
    """

    def __init__(self, data_dir, sample_rate=16000, n_mels=128, max_length=5.0):
        self.data_dir = data_dir
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.max_length = max_length
        self.files = []
        self.labels = []

        # os.listdir order is arbitrary; sorting makes item order (and thus
        # prediction order) deterministic and identical for every dataset
        # built from the same directory.
        for file in sorted(os.listdir(data_dir)):
            if file.endswith('.wav'):
                self.files.append(os.path.join(data_dir, file))
                emotion_id = int(file.split('-')[2]) - 1
                self.labels.append(emotion_id)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``{'input': (1, n_mels, frames) float32 dB tensor, 'label': long tensor}``."""
        file_path = self.files[idx]
        label = self.labels[idx]

        signal, sr = librosa.load(file_path, sr=self.sample_rate)
        max_samples = int(self.max_length * self.sample_rate)
        if len(signal) > max_samples:
            signal = signal[:max_samples]
        else:
            signal = np.pad(signal, (0, max_samples - len(signal)), 'constant')

        mel_spec = librosa.feature.melspectrogram(y=signal, sr=self.sample_rate, n_mels=self.n_mels)
        # dB relative to the loudest bin of this clip (per-clip normalization).
        mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

        # Add a leading channel axis for the single-channel Conv2d input.
        mel_spec_db = torch.tensor(mel_spec_db, dtype=torch.float32).unsqueeze(0)
        return {
            'input': mel_spec_db,
            'label': torch.tensor(label, dtype=torch.long)
        }

# Function to compute metrics
def compute_metrics(labels, preds, *, average='weighted', zero_division=0):
    """Return accuracy plus averaged precision, recall and F1.

    Args:
        labels: Ground-truth class ids.
        preds: Predicted class ids, aligned with ``labels``.
        average: Averaging strategy forwarded to sklearn's
            ``precision_recall_fscore_support``. The default is ``'weighted'``
            (per-class scores weighted by support).
        zero_division: Value used when a class has no predicted or true
            samples. 0 gives the same numbers as sklearn's ``'warn'`` default
            without emitting ``UndefinedMetricWarning`` on every evaluation.

    Returns:
        Dict with keys ``'accuracy'``, ``'precision'``, ``'recall'``, ``'f1'``.
    """
    accuracy = accuracy_score(labels, preds)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average=average, zero_division=zero_division
    )
    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1
    }

# Function to plot confusion matrix
def plot_confusion_matrix(labels, preds, title, emotions):
    """Show an annotated confusion-matrix heatmap.

    Args:
        labels: Ground-truth class ids (0 .. len(emotions) - 1).
        preds: Predicted class ids, aligned with ``labels``.
        title: Figure title.
        emotions: Class names in class-id order, used as tick labels.
    """
    # Pin the class set explicitly. Otherwise any class missing from both
    # labels and preds is dropped, the matrix shrinks, and the tick labels no
    # longer line up with the rows and columns.
    class_ids = list(range(len(emotions)))
    cm = confusion_matrix(y_true=labels, y_pred=preds, labels=class_ids)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=emotions, yticklabels=emotions)
    plt.title(title)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()

# Load Wav2Vec2 model and processor
processor = Wav2Vec2Processor.from_pretrained(WAV2VEC2_PATH)
wav2vec2_model = Wav2Vec2ForSequenceClassification.from_pretrained(WAV2VEC2_PATH)
wav2vec2_model.eval()

# Load CNN model. map_location='cpu' lets a checkpoint saved during a GPU run
# load on a CPU-only machine; the model is moved to the evaluation device later.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
cnn_model = CNN(num_classes=len(EMOTIONS))
cnn_model.load_state_dict(torch.load(CNN_PATH, map_location='cpu'))
cnn_model.eval()

# Load validation datasets. shuffle=False keeps predictions in dataset order.
wav2vec2_val_dataset = RAVDESSDatasetWav2Vec2(VAL_DIR, processor, MAX_LENGTH, SAMPLE_RATE)
cnn_val_dataset = RAVDESSMelDataset(VAL_DIR, SAMPLE_RATE, N_MELS, MAX_LENGTH)
wav2vec2_val_loader = DataLoader(wav2vec2_val_dataset, batch_size=8, shuffle=False)
cnn_val_loader = DataLoader(cnn_val_dataset, batch_size=8, shuffle=False)

# Evaluate Wav2Vec2
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
wav2vec2_model.to(device)
wav2vec2_preds, wav2vec2_labels = [], []
with torch.no_grad():
    for batch in wav2vec2_val_loader:
        input_values = batch['input_values'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        outputs = wav2vec2_model(input_values, attention_mask=attention_mask)
        wav2vec2_preds.extend(outputs.logits.argmax(dim=-1).cpu().numpy())
        # Labels are only collected for metrics, so they never need to
        # visit the device (avoids a host->device->host round trip per batch).
        wav2vec2_labels.extend(batch['labels'].numpy())

# Evaluate CNN
cnn_model.to(device)
cnn_preds, cnn_labels = [], []
with torch.no_grad():
    for batch in cnn_val_loader:
        inputs = batch['input'].to(device)
        outputs = cnn_model(inputs)
        cnn_preds.extend(outputs.argmax(dim=-1).cpu().numpy())
        cnn_labels.extend(batch['label'].numpy())

# Score both models on the same validation split.
wav2vec2_metrics = compute_metrics(wav2vec2_labels, wav2vec2_preds)
cnn_metrics = compute_metrics(cnn_labels, cnn_preds)

# Print each model's metrics; every header after the first is preceded by a blank line.
for position, (model_name, scores) in enumerate([('Wav2Vec2', wav2vec2_metrics), ('CNN', cnn_metrics)]):
    header = f"{model_name} Metrics:"
    print(header if position == 0 else "\n" + header)
    for metric_name, score in scores.items():
        print(f"{metric_name.capitalize()}: {score:.4f}")

# One confusion-matrix figure per model.
for y_true, y_pred, figure_title in (
    (wav2vec2_labels, wav2vec2_preds, 'Wav2Vec2 Confusion Matrix'),
    (cnn_labels, cnn_preds, 'CNN Confusion Matrix'),
):
    plot_confusion_matrix(y_true, y_pred, figure_title, EMOTIONS)

# Side-by-side table: rows are metrics, columns are models.
metrics_df = pd.DataFrame({'Wav2Vec2': wav2vec2_metrics, 'CNN': cnn_metrics})
print("\nModel Comparison:")
print(metrics_df)

# Grouped bar chart of the same table.
comparison_ax = metrics_df.plot.bar(figsize=(10, 6))
comparison_ax.set_title('Model Performance Comparison')
comparison_ax.set_ylabel('Score')
plt.xticks(rotation=0)
plt.tight_layout()
plt.show()

ModuleNotFoundError: No module named 'seaborn'

In [None]:
!pip install seaborn
