In [2]:
# ----------------------------------------
# Imports and Configuration
# ----------------------------------------
import os
import numpy as np
import pandas as pd
import pretty_midi
from sklearn.model_selection import StratifiedShuffleSplit, train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from tensorflow.keras import layers, models, Input, Model
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

# Constants
DATA_DIR = os.path.join('.', 'data')
COMPOSERS = ['Bach', 'Beethoven', 'Chopin', 'Mozart']
SUPPORTED_EXT = ('.mid', '.midi')

# ----------------------------------------
# Step 1: MIDI File Collection
# ----------------------------------------
def get_midi_files(data_dir, composers):
    all_paths = []
    for composer in composers:
        folder = os.path.join(data_dir, composer)
        files = [os.path.join(folder, f) for f in os.listdir(folder) if f.lower().endswith(SUPPORTED_EXT)]
        for file in files:
            all_paths.append((file, composer))
    return pd.DataFrame(all_paths, columns=['filepath', 'composer'])

# ----------------------------------------
# Step 2: MIDI Statistics Extraction
# ----------------------------------------
def extract_note_statistics(file_path):
    midi_data = pretty_midi.PrettyMIDI(file_path)
    pitches = [note.pitch for instr in midi_data.instruments for note in instr.notes]
    velocities = [note.velocity for instr in midi_data.instruments for note in instr.notes]
    _, tempos = midi_data.get_tempo_changes()

    return {
        'pitch_std': np.std(pitches),
        'pitch_min': np.min(pitches),
        'pitch_max': np.max(pitches),
        'velocity_mean': np.mean(velocities),
        'velocity_std': np.std(velocities),
        'velocity_min': np.min(velocities),
        'velocity_max': np.max(velocities),
        'tempo_mean': np.mean(tempos),
        'tempo_min': np.min(tempos),
        'tempo_max': np.max(tempos),
    }

def extract_statistics_dataframe(df):
    stats = []
    for _, row in df.iterrows():
        try:
            feats = extract_note_statistics(row['filepath'])
            feats['composer'] = row['composer']
            stats.append(feats)
        except Exception as e:
            print(f"Error in stats for {row['filepath']}: {e}")
    return pd.DataFrame(stats)

# ----------------------------------------
# Step 3: Multichannel Piano Roll
# ----------------------------------------
def process_multichannel_midi(file_path, fs=8, max_length=150):
    midi = pretty_midi.PrettyMIDI(file_path)
    piano_roll = midi.get_piano_roll(fs=fs)
    binary_roll = (piano_roll > 0).astype(np.float32)
    velocity_roll = (piano_roll / 127.0).astype(np.float32)

    instrument_rolls = [instr.get_piano_roll(fs=fs) for instr in midi.instruments]
    max_len = max((r.shape[1] for r in instrument_rolls), default=0)
    inst_combined = np.zeros((128, max_len))
    for roll in instrument_rolls:
        roll = np.pad(roll, ((0, 0), (0, max_len - roll.shape[1])), mode='constant')
        inst_combined += (roll > 0).astype(np.float32)
    inst_combined = (inst_combined / inst_combined.max()).astype(np.float32) if inst_combined.max() > 0 else inst_combined

    expressive_roll = np.zeros((128, max_len))
    for instr in midi.instruments:
        for note in instr.notes:
            start = int(note.start * fs)
            end = int(note.end * fs)
            expressive_roll[note.pitch, start:end] = 1
    expressive_roll = expressive_roll.astype(np.float32)

    def fix_length(arr): return np.pad(arr, ((0, 0), (0, max(0, max_length - arr.shape[1]))), mode='constant')[:, :max_length]
    binary_roll = fix_length(binary_roll)
    velocity_roll = fix_length(velocity_roll)
    inst_combined = fix_length(inst_combined)
    expressive_roll = fix_length(expressive_roll)

    return np.stack([binary_roll, velocity_roll, inst_combined, expressive_roll], axis=-1)

# ----------------------------------------
# Step 4: Dataset Creation
# ----------------------------------------
def create_combined_dataset(df, fs=8, max_length=150):
    X_roll, X_stats, y = [], [], []
    label_map = {name: idx for idx, name in enumerate(COMPOSERS)}
    for _, row in df.iterrows():
        try:
            roll = process_multichannel_midi(row['filepath'], fs=fs, max_length=max_length)
            stats = extract_note_statistics(row['filepath'])
            X_roll.append(roll)
            X_stats.append(list(stats.values()))
            y.append(label_map[row['composer']])
        except Exception as e:
            print(f"Error processing {row['filepath']}: {e}")
    return (np.array(X_roll).astype(np.float32),
            np.array(X_stats).astype(np.float32),
            to_categorical(np.array(y), num_classes=len(COMPOSERS)))

# ----------------------------------------
# Step 5: Split Dataset
# ----------------------------------------
def stratified_split_dataframe(df, label_col='composer', test_size=0.2, val_size=0.1, random_state=42):
    label_map = {label: idx for idx, label in enumerate(sorted(df[label_col].unique()))}
    y = df[label_col].map(label_map).values
    sss1 = StratifiedShuffleSplit(n_splits=1, test_size=(test_size + val_size), random_state=random_state)
    for train_idx, temp_idx in sss1.split(df, y):
        train_df, temp_df = df.iloc[train_idx], df.iloc[temp_idx]
    sss2 = StratifiedShuffleSplit(n_splits=1, test_size=test_size / (test_size + val_size), random_state=random_state)
    y_temp = temp_df[label_col].map(label_map).values
    for val_idx, test_idx in sss2.split(temp_df, y_temp):
        val_df, test_df = temp_df.iloc[val_idx], temp_df.iloc[test_idx]
    return train_df.reset_index(drop=True), val_df.reset_index(drop=True), test_df.reset_index(drop=True)

# ----------------------------------------
# Step 6: Model Building (Hybrid CNN + MLP)
# ----------------------------------------
def build_hybrid_model(piano_roll_shape, stats_shape, num_classes):
    roll_input = Input(shape=piano_roll_shape, name='roll_input')
    x = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(roll_input)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(x)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Flatten()(x)

    stats_input = Input(shape=(stats_shape,), name='stats_input')
    s = layers.Dense(64, activation='relu')(stats_input)

    merged = layers.concatenate([x, s])
    merged = layers.Dense(128, activation='relu')(merged)
    merged = layers.Dropout(0.4)(merged)
    output = layers.Dense(num_classes, activation='softmax')(merged)

    model = Model(inputs=[roll_input, stats_input], outputs=output)
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    return model

# ----------------------------------------
# Step 7: Class Weights
# ----------------------------------------
def compute_class_weights(y_train_onehot):
    y_int = np.argmax(y_train_onehot, axis=1)
    weights = compute_class_weight(class_weight='balanced', classes=np.unique(y_int), y=y_int)
    return dict(enumerate(weights))

# ----------------------------------------
# Step 8: Training Wrapper
# ----------------------------------------
def train_hybrid_model(model, X_roll_train, X_stats_train, y_train, 
                       X_roll_val, X_stats_val, y_val, epochs=20, batch_size=32):
    class_weights = compute_class_weights(y_train)
    callbacks = [
        EarlyStopping(patience=5, restore_best_weights=True),
        ReduceLROnPlateau(patience=3, factor=0.5)
    ]
    model.fit(
        {'roll_input': X_roll_train, 'stats_input': X_stats_train},
        y_train,
        validation_data=({'roll_input': X_roll_val, 'stats_input': X_stats_val}, y_val),
        epochs=epochs,
        batch_size=batch_size,
        class_weight=class_weights,
        callbacks=callbacks,
        verbose=1
    )
    return model


# ----------------------------------------
# 1. Load MIDI File Paths
# ----------------------------------------
df = get_midi_files(DATA_DIR, COMPOSERS)
df = df[df['filepath'].apply(os.path.exists)]  # Just to be safe
df = df.reset_index(drop=True)

# ----------------------------------------
# 2. Split into Train/Val/Test
# ----------------------------------------
train_df, val_df, test_df = stratified_split_dataframe(df)

# ----------------------------------------
# 3. Extract Features: Multichannel Piano Roll + Statistics
# ----------------------------------------
print("Extracting training data...")
X_train_roll, X_train_stats, y_train = create_combined_dataset(train_df)
print("Extracting validation data...")
X_val_roll, X_val_stats, y_val = create_combined_dataset(val_df)
print("Extracting test data...")
X_test_roll, X_test_stats, y_test = create_combined_dataset(test_df)

# ----------------------------------------
# 4. Build Hybrid Model
# ----------------------------------------
input_shape_roll = X_train_roll.shape[1:]    # (128, 150, 4)
input_shape_stats = X_train_stats.shape[1]   # e.g., 10

model = build_hybrid_model(input_shape_roll, input_shape_stats, num_classes=len(COMPOSERS))

# ----------------------------------------
# 5. Train the Model
# ----------------------------------------
trained_model = train_hybrid_model(model, 
                                   X_train_roll, X_train_stats, y_train,
                                   X_val_roll, X_val_stats, y_val)

# ----------------------------------------
# 6. Evaluate the Model (optional)
# ----------------------------------------
test_loss, test_acc = trained_model.evaluate(
    {'roll_input': X_test_roll, 'stats_input': X_test_stats}, y_test, verbose=0
)
print(f"✅ Final Test Accuracy: {test_acc:.2%}")


Extracting training data...




Extracting validation data...
Error processing .\data\Beethoven\Anhang 14-3.mid: Could not decode key with 3 flats and mode 255
Extracting test data...
Error processing .\data\Mozart\K281 Piano Sonata n03 3mov.mid: Could not decode key with 2 flats and mode 2
Epoch 1/20
[1m36/36[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 84ms/step - accuracy: 0.3167 - loss: 2.5605 - val_accuracy: 0.5644 - val_loss: 1.0759 - learning_rate: 0.0010
Epoch 2/20
[1m36/36[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 78ms/step - accuracy: 0.5591 - loss: 1.2815 - val_accuracy: 0.6380 - val_loss: 0.9917 - learning_rate: 0.0010
Epoch 3/20
[1m36/36[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 78ms/step - accuracy: 0.6084 - loss: 1.1998 - val_accuracy: 0.6933 - val_loss: 0.8412 - learning_rate: 0.0010
Epoch 4/20
[1m36/36[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 77ms/step - accuracy: 0.6603 - loss: 1.1045 - val_accuracy: 0.7117 - val_loss: 0.9162 - learning_rate: 0.0010
