In [1]:
%pip install pretty_midi kagglehub

Collecting pretty_midi
  Downloading pretty_midi-0.2.10.tar.gz (5.6 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/5.6 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m5.6/5.6 MB[0m [31m178.5 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.6/5.6 MB[0m [31m107.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting mido>=1.1.16 (from pretty_midi)
  Downloading mido-1.3.3-py3-none-any.whl.metadata (6.4 kB)
Downloading mido-1.3.3-py3-none-any.whl (54 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.6/54.6 kB[0m [31m6.8 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: pretty_midi
  Building wheel for pretty_midi (setup.py) ... [?25l[?25hdone
  Created wheel for pretty_midi: filename=pretty_midi-0.2.10-py3-none-any.whl size=5592286 sha256=6b7d68d4fa0805eb4

In [2]:
import os
import kagglehub
import zipfile
import shutil
import numpy as np
import torch
import torch.nn as nn
import os
import numpy as np
import pretty_midi
import torch
from torch.utils.data import Dataset

In [3]:
# Choose the compute device in priority order: Apple MPS, then CUDA, then CPU.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

In [None]:
# Compatibility shim: recent NumPy releases no longer expose `np.int`.
# Restore it as the builtin `int` so code that still references `np.int` keeps working
# (presumably pretty_midi — TODO confirm which dependency needs it).
try:
    getattr(np, 'int')
except AttributeError:
    np.int = int

In [None]:
# Define your model class (must match the architecture used in final-project1)
class CNN_LSTM_Classifier(nn.Module):
    """CNN front-end -> bidirectional LSTM -> self-attention -> MLP classifier.

    The input is a 4-D tensor with a single channel (the first convolution takes
    1 input channel). `feature_size` is fixed to 128 * 16, i.e. 128 channels times
    a height of 16 after the three 2x2 poolings, so the input height must be 128.
    The last (width) axis is treated as the sequence axis for the LSTM.

    Submodule names and their order inside each `nn.Sequential` define the
    state-dict keys, so they must not change if saved checkpoints are to load.

    Args:
        num_classes: Number of output logits.
        lstm_hidden: Hidden size per LSTM direction; the attention and classifier
            operate on `2 * lstm_hidden` features because the LSTM is bidirectional.
    """

    @staticmethod
    def _conv_stage(in_channels, out_channels, drop_prob):
        """Two 3x3 conv+BN+ReLU layers, channel dropout, then 2x2 max-pooling."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout2d(drop_prob),
            nn.MaxPool2d(kernel_size=(2, 2)),
        )

    def __init__(self, num_classes=4, lstm_hidden=256):
        super().__init__()

        # Convolutional feature extractor: 1 -> 32 -> 64 -> 128 channels,
        # each stage halving both spatial dimensions.
        self.conv1 = self._conv_stage(1, 32, 0.2)
        self.conv2 = self._conv_stage(32, 64, 0.3)
        self.conv3 = self._conv_stage(64, 128, 0.4)

        # Per-time-step feature width fed to the LSTM (channels * pooled height).
        self.feature_size = 128 * 16

        recurrent_width = lstm_hidden * 2  # forward + backward directions
        self.lstm = nn.LSTM(
            input_size=self.feature_size,
            hidden_size=lstm_hidden,
            num_layers=2,
            batch_first=True,
            dropout=0.3,
            bidirectional=True,
        )
        self.attention = nn.MultiheadAttention(
            embed_dim=recurrent_width,
            num_heads=8,
            dropout=0.3,
            batch_first=True,
        )
        self.classifier = nn.Sequential(
            nn.Linear(recurrent_width, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for input `x`."""
        n_items = x.size(0)
        feature_maps = self.conv3(self.conv2(self.conv1(x)))

        # Move the width axis next to the batch axis, then flatten
        # (channels, height) into one feature vector per time step.
        sequence = feature_maps.permute(0, 3, 1, 2).contiguous()
        sequence = sequence.view(n_items, sequence.size(1), -1)

        recurrent_out, _ = self.lstm(sequence)
        attended, _ = self.attention(recurrent_out, recurrent_out, recurrent_out)

        # Average over time steps before classification.
        return self.classifier(torch.mean(attended, dim=1))


In [None]:
# Load the models
from IPython import get_ipython

# On Colab the checkpoints live under /content; elsewhere they are resolved
# relative to the current working directory.
if 'google.colab' in str(get_ipython()):
    print("Running in Google Colab.")
    _models_dir = '/content/saved_models'
else:
    print("Not running in Google Colab.")
    _models_dir = 'saved_models'

original_path = os.path.join(_models_dir, 'original_cnn_lstm.pth')
rhythm_path = os.path.join(_models_dir, 'rhythm_augmented_cnn_lstm.pth')


def _load_classifier(checkpoint_path):
    """Build a CNN_LSTM_Classifier on `device` and restore weights from `checkpoint_path`.

    The returned model is switched to evaluation mode so Dropout is disabled and
    BatchNorm uses its stored running statistics; call `.train()` explicitly
    before any fine-tuning.
    """
    classifier = CNN_LSTM_Classifier(num_classes=4, lstm_hidden=256).to(device)
    # NOTE(security): torch.load unpickles the file; only load checkpoints you trust
    # (consider `weights_only=True` on PyTorch versions that support it).
    classifier.load_state_dict(torch.load(checkpoint_path, map_location=device))
    classifier.eval()
    print(f"✅ Loaded: {checkpoint_path}")
    return classifier


model = _load_classifier(original_path)
rhythm_model = _load_classifier(rhythm_path)

Running in Google Colab.
✅ Loaded: /content/saved_models/original_cnn_lstm.pth
✅ Loaded: /content/saved_models/rhythm_augmented_cnn_lstm.pth


In [4]:
TARGET_COMPOSERS = [
    'Bach',
    'Beethoven',
    'Chopin',
    'Mozart',
]

path = kagglehub.dataset_download("blanderbuss/midi-classic-music")

zip_path = os.path.join(path, 'midiclassics.zip')
extract_path = os.path.join('data', 'kaggle', 'midiclassics')


with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_path)


def _remove_entry(entry_path):
    """Delete a regular file or a whole directory tree; other entry types are left alone."""
    if os.path.isfile(entry_path):
        os.remove(entry_path)
    elif os.path.isdir(entry_path):
        shutil.rmtree(entry_path)


# Keep only top-level entries whose name mentions a target composer (case-insensitive).
# Also drop "C.P.E.Bach" entries: he was the son of J.S. Bach, and we want to keep
# only the main composers.
for item in os.listdir(extract_path):
    _item_lower = item.lower()
    _mentions_target = any(composer.lower() in _item_lower for composer in TARGET_COMPOSERS)
    if not _mentions_target or 'C.P.E.Bach' in item:
        _remove_entry(os.path.join(extract_path, item))

In [5]:
class PianoRollDataset(Dataset):
    """In-memory dataset pairing piano-roll arrays with integer class labels.

    `data` is converted to a float32 tensor and `labels` to an int64 tensor once,
    at construction time. Each item is returned with a leading channel axis so it
    can be fed to a 2-D CNN.
    """

    def __init__(self, data, labels):
        self.data = torch.tensor(data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        # Number of samples is the size of the first axis of the data tensor.
        return self.data.shape[0]

    def __getitem__(self, idx):
        sample = self.data[idx]
        target = self.labels[idx]
        # Prepend the channel dimension for the CNN: (128, T) -> (1, 128, T)
        return sample[None, ...], target

In [6]:
def _fit_piano_roll_width(piano_roll, target_length, center_crop=False):
    """Crop or zero-pad a (pitch, time) array along the time axis to `target_length` frames.

    When the roll is too long, `center_crop=True` keeps the middle frames;
    otherwise the leading frames are kept.
    """
    width = piano_roll.shape[1]
    if width > target_length:
        offset = (width - target_length) // 2 if center_crop else 0
        return piano_roll[:, offset:offset + target_length]
    return np.pad(piano_roll, ((0, 0), (0, target_length - width)), mode='constant')


def _extract_midi_window(pm, start_time, duration):
    """Copy the notes of `pm` that START inside [start_time, start_time + duration).

    Note times are shifted so the window begins at 0 and note ends are clipped to
    `duration`. Instruments with no note in the window are omitted.
    """
    end_time = start_time + duration
    window = pretty_midi.PrettyMIDI()
    for instrument in pm.instruments:
        clipped = pretty_midi.Instrument(
            program=instrument.program,
            is_drum=instrument.is_drum,
            name=instrument.name
        )
        clipped.notes.extend(
            pretty_midi.Note(
                velocity=note.velocity,
                pitch=note.pitch,
                start=note.start - start_time,
                end=min(note.end - start_time, duration)
            )
            for note in instrument.notes
            if start_time <= note.start < end_time
        )
        if clipped.notes:
            window.instruments.append(clipped)
    return window


def get_piano_roll_improved(midi_path, fs=100, target_duration=45.0):
    """
    Convert a MIDI file into one or more fixed-width piano rolls.

    Args:
        midi_path: Path of the MIDI file to read.
        fs: Piano-roll sampling rate in frames per second.
        target_duration: Length in seconds of every returned roll.

    Returns:
        A list of arrays, each `int(target_duration * fs)` frames wide, or None when
        the piece is shorter than 10 seconds, when no notes fall inside any sampled
        window, or when the file cannot be processed (the error is printed).

        * Pieces longer than 1.5 x target_duration yield up to three windows whose
          start times are evenly spread over the piece.
        * Other pieces yield a single roll of the whole piece, centre-cropped or
          zero-padded on the right.
    """
    try:
        pm = pretty_midi.PrettyMIDI(midi_path)

        # Get the actual duration of the piece
        actual_duration = pm.get_end_time()

        # If piece is very short, skip it
        if actual_duration < 10.0:  # Less than 10 seconds
            return None

        target_length = int(target_duration * fs)

        if actual_duration <= target_duration * 1.5:
            # Normal-length piece: use the whole piece, taking the middle rather
            # than truncating the end when it is too long.
            whole_roll = pm.get_piano_roll(fs=fs)
            return [_fit_piano_roll_width(whole_roll, target_length, center_crop=True)]

        # Long piece: sample windows from different parts of it.
        num_segments = min(3, int(actual_duration // target_duration))
        stride = actual_duration / num_segments

        segments = []
        for i in range(num_segments):
            window = _extract_midi_window(pm, i * stride, target_duration)
            if window.instruments:
                segments.append(
                    _fit_piano_roll_width(window.get_piano_roll(fs=fs), target_length)
                )

        # Report "nothing usable" as None, like every other failure path,
        # instead of handing callers an empty list.
        return segments if segments else None

    except Exception as e:
        print(f"Error processing {midi_path}: {e}")
        return None


In [7]:
def normalize_piano_roll(piano_roll):
    """
    Return an independent copy of the piano roll.

    Despite the name, no value in the array is changed: velocities are not
    rescaled and the pitch axis is not cropped. The copy only protects the
    caller's array from later in-place edits.

    NOTE(review): an earlier comment here claimed the values were already in
    0-1. Nothing in this function establishes that; pretty_midi piano rolls are
    velocity-valued, so confirm against the training pipeline before adding any
    scaling, because the saved models presumably expect the current values.

    Args:
        piano_roll: NumPy array holding the piano roll.

    Returns:
        A new array with the same shape, dtype and contents as `piano_roll`.
    """
    return piano_roll.copy()


In [8]:
def extract_musical_features(piano_roll):
    """
    Compute summary statistics of a (pitch, time) piano roll.

    A cell counts as "active" when its value is > 0.

    Returns a dict with:
        avg_notes_per_time:    mean number of active pitches per frame
        note_density_variance: variance of that per-frame count
        pitch_range:           number of distinct pitches that are ever active
                               (a count, not highest - lowest)
        lowest_pitch:          row index of the lowest active pitch (60 if silent)
        highest_pitch:         row index of the highest active pitch (60 if silent)
        onset_density:         fraction of frame-to-frame steps that go from
                               silence (no active pitch) to sound
    """
    features = {}
    active = piano_roll > 0

    # Temporal features
    note_density_timeline = np.sum(active, axis=0)
    if note_density_timeline.size:
        features['avg_notes_per_time'] = np.mean(note_density_timeline)
        features['note_density_variance'] = np.var(note_density_timeline)
    else:
        # No frames at all: report zeros instead of NaN from an empty mean.
        features['avg_notes_per_time'] = 0.0
        features['note_density_variance'] = 0.0

    # Pitch features
    active_pitches = np.any(active, axis=1)
    if np.any(active_pitches):
        features['pitch_range'] = np.sum(active_pitches)
        features['lowest_pitch'] = np.argmax(active_pitches)
        # Derive the top index from the actual number of rows rather than assuming 128.
        features['highest_pitch'] = len(active_pitches) - 1 - np.argmax(active_pitches[::-1])
    else:
        features['pitch_range'] = 0
        features['lowest_pitch'] = 60  # Middle C
        features['highest_pitch'] = 60

    # Rhythmic features
    # Cast to int BEFORE differencing: np.diff on a boolean array yields True for
    # any change, which would count sound->silence steps as onsets too.
    sounding = (note_density_timeline > 0).astype(int)
    transitions = np.diff(sounding)
    if len(transitions):
        features['onset_density'] = np.sum(transitions == 1) / len(transitions)
    else:
        # Fewer than two frames: there is no step to measure.
        features['onset_density'] = 0.0

    return features

print("✅ Improved data processing functions defined!")
print("Key improvements:")
print("• Intelligent segment extraction for long pieces")
print("• Musical boundary awareness")
print("• Better normalization")
print("• Feature extraction for analysis")

✅ Improved data processing functions defined!
Key improvements:
• Intelligent segment extraction for long pieces
• Musical boundary awareness
• Better normalization
• Feature extraction for analysis


In [10]:
def get_piano_roll_non_overlapping_segments(midi_path, fs=100, segment_duration=45.0):
    """
    Split a MIDI file into consecutive, non-overlapping fixed-width piano rolls.

    Args:
        midi_path: Path of the MIDI file to read.
        fs: Piano-roll sampling rate in frames per second.
        segment_duration: Length in seconds of each segment.

    Returns:
        A list of arrays, each `int(segment_duration * fs)` frames wide, or None
        when the piece is shorter than 15 seconds, when no segment contains a
        note, or when the file cannot be processed (the error is printed).

        A piece shorter than one segment is returned whole, zero-padded on the
        right. Otherwise only full-length windows are produced; the remainder
        after the last full window is not used.
    """
    try:
        midi = pretty_midi.PrettyMIDI(midi_path)
        total_seconds = midi.get_end_time()

        # Very short pieces are not worth keeping.
        if total_seconds < 15.0:  # Less than 15 seconds
            return None

        frames_per_segment = int(segment_duration * fs)

        def to_fixed_width(roll):
            # Keep the leading frames when too long, zero-pad on the right when too short.
            n_frames = roll.shape[1]
            if n_frames >= frames_per_segment:
                return roll[:, :frames_per_segment]
            return np.pad(roll, ((0, 0), (0, frames_per_segment - n_frames)), mode='constant')

        full_windows = int(total_seconds // segment_duration)

        # Shorter than one segment: use the whole piece (padded).
        if full_windows == 0:
            return [to_fixed_width(midi.get_piano_roll(fs=fs))]

        rolls = []
        for window_start in (k * segment_duration for k in range(full_windows)):
            window_end = window_start + segment_duration

            # Rebuild a MIDI object containing only notes that start in this window,
            # shifted so the window begins at time 0.
            clip = pretty_midi.PrettyMIDI()
            for source in midi.instruments:
                kept = [n for n in source.notes if window_start <= n.start < window_end]
                if not kept:
                    continue
                track = pretty_midi.Instrument(
                    program=source.program,
                    is_drum=source.is_drum,
                    name=source.name
                )
                for n in kept:
                    track.notes.append(
                        pretty_midi.Note(
                            velocity=n.velocity,
                            pitch=n.pitch,
                            start=n.start - window_start,
                            end=min(n.end - window_start, segment_duration)
                        )
                    )
                clip.instruments.append(track)

            # Windows without any note are skipped.
            if clip.instruments:
                rolls.append(to_fixed_width(clip.get_piano_roll(fs=fs)))

        return rolls or None

    except Exception as e:
        print(f"Error processing {midi_path}: {e}")
        return None

print("✅ Non-overlapping segmentation function defined!")

✅ Non-overlapping segmentation function defined!


In [11]:
def load_dataset_non_overlapping_full(extract_path, target_composers, segment_duration=45.0,
                                      min_note_density=0.1):
    """
    Load FULL dataset with non-overlapping segments - using ALL available files

    For every composer in `target_composers`, reads the MIDI files located
    directly inside `<extract_path>/<composer>` (sub-directories are not
    searched), splits each into non-overlapping segments, and collects the
    segments together with their label and feature dict.

    Args:
        extract_path: Directory that contains one sub-directory per composer.
        target_composers: Composer names; the position in this sequence is the label.
        segment_duration: Segment length in seconds, passed to the segmenter.
        min_note_density: Segments whose 'avg_notes_per_time' feature is below
            this value are discarded as too sparse.

    Returns:
        (data, labels, features): `data` stacks all kept segments along axis 0,
        `labels` holds the composer index of each segment, and `features` is a
        list with one feature dict per segment, in the same order.
        Returns (None, None, None) when nothing was loaded.

    NOTE(review): segments never overlap, but segments cut from the same file can
    still end up on both sides of a later train/test split if that split is made
    per segment rather than per file; the split is not visible from here.
    """
    print("🎵 LOADING FULL DATASET WITH NON-OVERLAPPING SEGMENTS...")
    print("Benefits for LSTM training:")
    print(f"• Segment duration: {segment_duration}s (no overlap)")
    print("• Using ALL available files (no limits)")
    print("• No data leakage between train/test")
    print("• Cleaner temporal boundaries")
    print("• Will address class imbalance later during training")

    composer_to_idx = {c: i for i, c in enumerate(target_composers)}
    all_data = []
    all_labels = []
    all_features = []

    total_files_processed = 0
    total_segments_created = 0

    for composer in target_composers:
        print(f"\n--- Processing {composer} ---")
        composer_dir = os.path.join(extract_path, composer)

        if not os.path.isdir(composer_dir):
            print(f"Directory not found: {composer_dir}")
            continue

        composer_data = []
        composer_labels = []
        composer_features = []
        files_processed = 0
        segments_created = 0

        # Get ALL MIDI files - no limit. os.listdir returns names in arbitrary
        # order, so sort them to make the sample order reproducible across runs.
        midi_files = sorted(
            f for f in os.listdir(composer_dir)
            if f.lower().endswith(('.mid', '.midi'))
        )

        print(f"  Found {len(midi_files)} MIDI files for {composer}")

        for file in midi_files:
            midi_path = os.path.join(composer_dir, file)

            try:
                # Use non-overlapping segmentation
                segments = get_piano_roll_non_overlapping_segments(
                    midi_path,
                    segment_duration=segment_duration
                )

                # None means the file was too short, empty, or unreadable.
                if segments is None:
                    continue

                for segment in segments:
                    # Normalize the segment
                    normalized_segment = normalize_piano_roll(segment)

                    # Extract musical features
                    features = extract_musical_features(normalized_segment)

                    # Quality check: skip if too sparse
                    if features['avg_notes_per_time'] < min_note_density:
                        continue

                    composer_data.append(normalized_segment)
                    composer_labels.append(composer_to_idx[composer])
                    composer_features.append(features)
                    segments_created += 1

                # Counts files that produced segments, even if all were filtered out.
                files_processed += 1

                # Progress update every 50 files for full dataset
                if files_processed % 50 == 0:
                    print(f"  Processed {files_processed}/{len(midi_files)} files, created {segments_created} segments...")

            except Exception as e:
                # Best effort: one bad file must not abort the whole load.
                print(f"  Error processing {file}: {e}")
                continue

        print(f"✅ {composer}: {files_processed}/{len(midi_files)} files → {segments_created} segments")

        if composer_data:
            composer_data = np.array(composer_data)
            composer_labels = np.array(composer_labels)

            all_data.append(composer_data)
            all_labels.append(composer_labels)
            all_features.extend(composer_features)

            print(f"  Final data shape: {composer_data.shape}")

        total_files_processed += files_processed
        total_segments_created += segments_created

    # Combine all data
    if all_data:
        data = np.concatenate(all_data, axis=0)
        labels = np.concatenate(all_labels, axis=0)

        print(f"\n🎯 FINAL FULL NON-OVERLAPPING DATASET:")
        print(f"Total files processed: {total_files_processed}")
        print(f"Total samples: {len(data)}")
        print(f"Data shape: {data.shape}")
        # minlength keeps one count per composer even when trailing classes are empty.
        print(f"Label distribution: {np.bincount(labels, minlength=len(target_composers))}")

        # Show class distribution percentages
        for i, composer in enumerate(target_composers):
            count = np.sum(labels == i)
            percentage = (count / len(labels)) * 100
            print(f"  {composer}: {count} samples ({percentage:.1f}%)")

        return data, labels, all_features
    else:
        print("❌ No data loaded!")
        return None, None, None

# Build the complete dataset: every available file, 45-second segments, no overlap.
print("🚀 Starting FULL non-overlapping segment data loading...")
full_data, full_labels, full_features = load_dataset_non_overlapping_full(
    extract_path=extract_path,
    target_composers=TARGET_COMPOSERS,
    segment_duration=45.0,
)

🚀 Starting FULL non-overlapping segment data loading...
🎵 LOADING FULL DATASET WITH NON-OVERLAPPING SEGMENTS...
Benefits for LSTM training:
• Segment duration: 45.0s (no overlap)
• Using ALL available files (no limits)
• No data leakage between train/test
• Cleaner temporal boundaries
• Will address class imbalance later during training

--- Processing Bach ---
  Found 131 MIDI files for Bach
  Processed 50/131 files, created 394 segments...
  Processed 100/131 files, created 715 segments...
✅ Bach: 131/131 files → 977 segments
  Final data shape: (977, 128, 4500)

--- Processing Beethoven ---
  Found 134 MIDI files for Beethoven
Error processing data/kaggle/midiclassics/Beethoven/Anhang 14-3.mid: Could not decode key with 3 flats and mode 255
  Processed 50/134 files, created 357 segments...
  Processed 100/134 files, created 765 segments...
✅ Beethoven: 133/134 files → 987 segments
  Final data shape: (987, 128, 4500)

--- Processing Chopin ---
  Found 136 MIDI files for Chopin
  Pro