In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
from torchvision import transforms

from utils import train_model, evaluate_model
from Datasets import MelSpectrogramDataset
# Models
from Models import CRNN

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Reproducibility: seed the RNGs used by model initialisation and DataLoader
# shuffling. torch.manual_seed also seeds every CUDA device. Some cuDNN kernels
# are still nondeterministic, so GPU runs may differ slightly between runs.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Added to the std so a constant-valued spectrogram does not divide by zero.
NORM_EPS = 1e-6


def normalize_mel(x):
    """Standardize one Mel spectrogram to zero mean and unit variance.

    Mean and std are taken over all elements of this single sample (no axis
    argument), i.e. per-sample normalization, not dataset-wide statistics.

    Args:
        x: one spectrogram as handed to the transform by the dataset —
            presumably a NumPy array; TODO confirm in
            ``Datasets.MelSpectrogramDataset``.

    Returns:
        ``(x - mean(x)) / (std(x) + NORM_EPS)``.
    """
    return (x - np.mean(x)) / (np.std(x) + NORM_EPS)


# Unlike a lambda, a named function can be pickled by reference. If DataLoader
# workers (num_workers > 0) are ever enabled on Windows, move this function into
# a .py module so the spawned worker processes can import it.
transform = transforms.Compose([
    transforms.Lambda(normalize_mel),
])

# Data root is configurable through the MEL_SPECTROGRAM_DIR environment
# variable; the original machine-specific location is only the fallback.
DATA_DIR = os.environ.get(
    "MEL_SPECTROGRAM_DIR",
    "C:/Users/jimmy/Desktop/Practical_Work/processed_data/mel_spectrogram",
)
train_dir = os.path.join(DATA_DIR, "train")
test_dir = os.path.join(DATA_DIR, "test")

In [3]:
def _build_dataset(split_dir):
    """Create a MelSpectrogramDataset for one split directory.

    ``split_dir`` is passed as the features directory, labels are taken from
    ``split_dir/labels.npy``, and the shared ``transform`` is applied.
    """
    labels_file = os.path.join(split_dir, "labels.npy")
    return MelSpectrogramDataset(
        features_dir=split_dir,
        labels_path=labels_file,
        transform=transform,
    )


train_dataset = _build_dataset(train_dir)
test_dataset = _build_dataset(test_dir)

# Only the training split is shuffled; the test split keeps a fixed order.
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device_name = "cuda"
else:
    device_name = "cpu"
device = torch.device(device_name)
print(device)

cuda


## Models


In [4]:
# CRNN hyperparameters, named instead of passed as bare literals.
# NOTE(review): assumes each sample is (1, 128, 216), i.e. (channels, Mel bins,
# time frames) — confirm against the preprocessing that produced train_dir.
INPUT_CHANNELS = 1
N_MELS = 128        # passed as img_height
N_FRAMES = 216      # passed as img_width
# NOTE(review): presumably the number of distinct labels in labels.npy; verify.
NUM_CLASSES = 50

model = CRNN(
    input_channels=INPUT_CHANNELS,
    img_height=N_MELS,
    img_width=N_FRAMES,
    num_classes=NUM_CLASSES,
)

In [5]:
# Train the model
train_model(model, train_loader, test_loader, device, num_epochs=100)

# Evaluate the model on the test set
test_accuracy = evaluate_model(model, test_loader, device)
print(f"Test Accuracy: {test_accuracy:.2f}%")


Epoch [1/100], Loss: 3.8373, Accuracy: 3.75%
Epoch [2/100], Loss: 3.5848, Accuracy: 5.94%
Epoch [3/100], Loss: 3.4412, Accuracy: 6.69%
Epoch [4/100], Loss: 3.2612, Accuracy: 10.62%
Epoch [5/100], Loss: 3.2328, Accuracy: 9.81%
Epoch [6/100], Loss: 3.1658, Accuracy: 12.75%
Epoch [7/100], Loss: 3.0839, Accuracy: 13.38%
Epoch [8/100], Loss: 2.9564, Accuracy: 16.06%
Epoch [9/100], Loss: 2.8304, Accuracy: 19.62%
Epoch [10/100], Loss: 2.7900, Accuracy: 20.44%
Epoch [11/100], Loss: 2.7223, Accuracy: 22.06%
Epoch [12/100], Loss: 2.6458, Accuracy: 23.06%
Epoch [13/100], Loss: 2.5247, Accuracy: 25.50%
Epoch [14/100], Loss: 2.4736, Accuracy: 27.50%
Epoch [15/100], Loss: 2.4225, Accuracy: 28.88%
Epoch [16/100], Loss: 2.3308, Accuracy: 30.94%
Epoch [17/100], Loss: 2.3196, Accuracy: 30.88%
Epoch [18/100], Loss: 2.1609, Accuracy: 33.25%
Epoch [19/100], Loss: 2.0903, Accuracy: 35.62%
Epoch [20/100], Loss: 2.0499, Accuracy: 38.69%
Epoch [21/100], Loss: 1.9979, Accuracy: 39.88%
Epoch [22/100], Loss: 1.87