In [None]:
import torch
import torch.nn as nn
import librosa
import math
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

In [None]:
# Locate the dataset: ../data/genres/<genre>/*.wav, relative to the directory
# the notebook is run from.
data_directory = Path.cwd().parent / "data"
genres_dir = data_directory / "genres"

# Map each genre name to the list of its .wav files.
# Both the genre directories and the files are sorted because Path.iterdir()
# and Path.glob() yield entries in arbitrary, filesystem-dependent order.
# Without sorting, the genre -> label index mapping and the seeded train/test
# split derived from this dict could differ between machines and runs.
genres = {
    genre_dir.name: sorted(genre_dir.glob("*.wav"))
    for genre_dir in sorted(genres_dir.iterdir())
    if genre_dir.is_dir()
}

In [None]:
# Hyperparameters

# -- Model architecture --
input_size = 20    # number of MFCC coefficients fed to the LSTM per time step
hidden_size = 128  # width of the LSTM hidden state
num_layers = 2     # number of stacked LSTM layers
num_classes = 10   # number of genres

# -- Optimisation --
learning_rate = 1e-3
batch_size = 64
num_epochs = 10

In [None]:
# LSTM model
class LSTMGenreModel(nn.Module):
    """Sequence classifier: a stacked LSTM followed by a linear read-out.

    Args:
        input_size: Number of features in each time step of the input.
        hidden_size: Width of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of logits produced per input sequence.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_size, out_features=num_classes)

    def forward(self, x):
        """Compute class logits for a batch of sequences.

        Args:
            x: Tensor of shape (batch, seq_len, input_size).

        Returns:
            Tensor of shape (batch, num_classes).
        """
        # Per-step outputs of the LSTM: (batch, seq_len, hidden_size).
        # The final (h_n, c_n) state tuple is not needed here.
        sequence_out, _ = self.lstm(x)

        # Only the last time step is passed to the classifier.
        last_step = sequence_out[:, -1]
        return self.fc(last_step)

# Dataset
class MFCCDataset(Dataset):
    """Pairs each feature sequence with its label, by position.

    Args:
        X: Indexable collection of feature sequences.
        y: Indexable collection of labels, aligned with ``X`` by index.
    """

    def __init__(self, X, y):
        self.X, self.y = X, y

    def __len__(self):
        # The label collection defines the dataset size.
        return len(self.y)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at position ``idx``."""
        return self.X[idx], self.y[idx]
        

In [None]:
# Audio clip geometry
SAMPLE_RATE = 22050  # samples per second
DURATION = 30  # length of audio files measured in seconds
SAMPLES_PER_TRACK = DURATION * SAMPLE_RATE  # samples in one full-length track
# NOTE(review): NUM_SEGMENTS is not referenced by the other cells visible in
# this notebook (feature extraction defines its own num_segments) -- confirm
# whether it is still needed.
NUM_SEGMENTS = 1

In [None]:
# Generate MFCC features
mfccs = []   # one (num_frames, input_size) array per kept segment
labels = []  # genre index for the segment at the same position in `mfccs`

# Create genre string to index mapping
genre2idx = {genre: i for i, genre in enumerate(genres)}

# MFCC parameters
num_segments = 5  # each track is cut into this many equal-length segments
hop_length = 512  # samples between successive MFCC frames
num_samples_per_segment = SAMPLES_PER_TRACK // num_segments
# Number of frames a full-length segment yields; shorter segments are dropped
# below so every kept sequence has the same length and can be batched.
expected_num_mfcc_vectors_per_segment = math.ceil(num_samples_per_segment / hop_length)

# Extract MFCCs
for genre, paths in genres.items():
    for path in paths:
        try:
            # Load at SAMPLE_RATE explicitly: the segment arithmetic above is
            # derived from it, so it must not depend on librosa's default.
            signal, sr = librosa.load(path, sr=SAMPLE_RATE)

            # Split audio into fixed segments
            for s in range(num_segments):
                start_sample = num_samples_per_segment * s
                finish_sample = start_sample + num_samples_per_segment

                # Pass n_mfcc and hop_length explicitly so the features stay
                # consistent with `input_size` (the LSTM input width) and with
                # the expected frame count computed from `hop_length`.
                mfcc = librosa.feature.mfcc(
                    y=signal[start_sample:finish_sample],
                    sr=sr,
                    n_mfcc=input_size,
                    hop_length=hop_length,
                )

                # (n_mfcc, frames) -> (frames, n_mfcc): time becomes the
                # leading axis, as the batch_first LSTM expects.
                mfcc = mfcc.T

                # Only keep MFCCs of expected length
                if len(mfcc) == expected_num_mfcc_vectors_per_segment:
                    mfccs.append(mfcc)
                    labels.append(genre2idx[genre])
                    print(f"{path}, segment: {s+1}")

        except Exception as e:
            # Best effort: report the unreadable/unprocessable track and
            # continue with the remaining files.
            print(f"Error loading {path}: {e}")

In [None]:
# Split data, build the loaders and train the model.

# Split data
X_train, X_test, y_train, y_test = train_test_split(
    mfccs, labels, test_size=0.2, random_state=42
)

# Create datasets and dataloaders
train_dataset = MFCCDataset(X_train, y_train)
test_dataset = MFCCDataset(X_test, y_test)

# Seed torch so weight initialisation and the per-epoch shuffling below are
# reproducible, matching the seeded split above.
torch.manual_seed(42)

# Reshuffle the training set every epoch; the test set keeps a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

# Train model
model = LSTMGenreModel(input_size, hidden_size, num_layers, num_classes)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

model.train()
for epoch in range(num_epochs):
    # Batch variables get their own names: reusing `mfccs` / `labels` here
    # would overwrite the full feature and label lists, so re-running this
    # cell would split the last batch instead of the dataset.
    for batch_mfccs, batch_labels in train_loader:

        optimizer.zero_grad()

        # Forward pass
        outputs = model(batch_mfccs)

        # Calculate loss
        loss = criterion(outputs, batch_labels)

        # Backward pass
        loss.backward()

        # Update weights
        optimizer.step()

In [None]:
# Evaluation
# Switch to inference mode so any train-only layer behaviour is disabled.
model.eval()

correct = 0
total = 0
with torch.no_grad():
    # Batch variables get their own names so the full `mfccs` / `labels`
    # lists built during feature extraction are not overwritten.
    for batch_mfccs, batch_labels in test_loader:
        outputs = model(batch_mfccs)
        # Predicted class = index of the largest logit in each row.
        predicted = outputs.argmax(dim=1)
        total += batch_labels.size(0)
        correct += (predicted == batch_labels).sum().item()

print("Accuracy: {}%".format(100 * correct / total))
