In [26]:
# Importing necessary libraries
import pretty_midi
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
from sklearn.metrics import confusion_matrix, classification_report

In [21]:
# Creating function for converting MIDI file to a piano roll representation
def midi_to_pianoroll(filepath, fs=100):
    try:
        midi = pretty_midi.PrettyMIDI(filepath)
        piano_roll = midi.get_piano_roll(fs=fs)
        return piano_roll.astype(np.float32)
    except Exception as e:
        raise ValueError(f"Failed to process MIDI file '{filepath}': {e}")

In [20]:
# Creating Dataset Class for PyTorch
class MidiComposerDataset(Dataset):
    def __init__(self, root_dir, fs=100, max_length=1000, allowed_composers=None):
        self.samples = []
        self.labels = []
        self.label_map = {}
        self.fs = fs
        self.max_length = max_length

        if allowed_composers is None:
            allowed_composers = ["Bach", "Chopin", "Mozart", "Beethoven"]

        for idx, composer in enumerate(sorted(allowed_composers)):
            self.label_map[composer] = idx
            composer_dir = os.path.join(root_dir, composer)
            if not os.path.isdir(composer_dir):
                continue

            for root, _, files in os.walk(composer_dir):
                for file in files:
                    if file.lower().endswith((".mid", ".midi")):
                        filepath = os.path.join(root, file)
                        self.samples.append(filepath)
                        self.labels.append(idx)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        filepath = self.samples[idx]
        label = self.labels[idx]

        try:
            piano_roll = midi_to_pianoroll(filepath, fs=self.fs)
        except Exception as e:
            print(f"[WARNING] Skipping corrupt file: {filepath}\n   Reason: {e}")
            new_idx = np.random.randint(0, len(self))
            return self.__getitem__(new_idx)

        if piano_roll.shape[1] > self.max_length:
            piano_roll = piano_roll[:, :self.max_length]
        else:
            pad_width = self.max_length - piano_roll.shape[1]
            piano_roll = np.pad(piano_roll, ((0, 0), (0, pad_width)), mode='constant')

        piano_roll = torch.tensor(piano_roll).unsqueeze(0)  # (1, 128, T)
        label_tensor = torch.tensor(label, dtype=torch.long)
        return piano_roll, label_tensor

In [22]:
# Checking that preprocessing with pretty_midi is working correctly
dataset = MidiComposerDataset(root_dir="midiclassics")
print(dataset.label_map)

{'Bach': 0, 'Beethoven': 1, 'Chopin': 2, 'Mozart': 3}


In [23]:
# Creating function for a stratified train/test split
def stratified_split(dataset, test_size=0.2, random_state=0):
    y = dataset.labels
    train_idx, test_idx = train_test_split(
        range(len(y)),
        test_size=test_size,
        random_state=random_state,
        stratify=y
    )
    
    return Subset(dataset, train_idx), Subset(dataset, test_idx)

In [24]:
# Creating function for a CNN/LSTM Classifier
class CNNLSTMClassifier(nn.Module):
    def __init__(self, input_channels=1, num_classes=4, lstm_hidden=128):
        super(CNNLSTMClassifier, self).__init__()

        self.cnn = nn.Sequential(
            nn.Conv2d(input_channels, 16, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(1, 2)),  

            nn.Conv2d(16, 32, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(1, 2)),  
        )

        self.lstm_input_size = 32 * 128  

        self.lstm = nn.LSTM(
            input_size=self.lstm_input_size,
            hidden_size=lstm_hidden,
            num_layers=1,
            batch_first=True
        )

        self.fc = nn.Linear(lstm_hidden, num_classes)

    def forward(self, x):
        x = self.cnn(x)  
        x = x.permute(0, 3, 1, 2)  
        B, T, C, H = x.shape
        x = x.reshape(B, T, C * H)  

        lstm_out, _ = self.lstm(x)  
        out = self.fc(lstm_out[:, -1, :])
        return out

In [25]:
# Testing functions with simple model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_dataset, val_dataset = stratified_split(dataset)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

model = CNNLSTMClassifier().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Doing a single training epoch for testing purposes
model.train()
for batch in train_loader:
    inputs, labels = batch
    inputs = inputs.to(device)
    labels = labels.to(device)
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()



   Reason: Failed to process MIDI file 'midiclassics\Mozart\Piano Sonatas\Nueva carpeta\K281 Piano Sonata n03 3mov.mid': Could not decode key with 2 flats and mode 2
   Reason: Failed to process MIDI file 'midiclassics\Beethoven\Anhang 14-3.mid': Could not decode key with 3 flats and mode 255


In [27]:
# Evaluating test model
model.eval()

correct = 0
total = 0
losses = []

criterion = nn.CrossEntropyLoss() 

with torch.no_grad():
    for inputs, labels in val_loader:  
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        losses.append(loss.item())

        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
avg_loss = sum(losses) / len(losses)

print(f"Evaluation Results:")
print(f"Loss: {avg_loss:.4f}")
print(f"Accuracy: {accuracy:.2f}%")



Evaluation Results:
  Loss: 0.9457
  Accuracy: 64.11%
