In [1]:
%pip install pretty_midi kagglehub

Collecting pretty_midi
  Downloading pretty_midi-0.2.10.tar.gz (5.6 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/5.6 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m5.6/5.6 MB[0m [31m178.5 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.6/5.6 MB[0m [31m107.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting mido>=1.1.16 (from pretty_midi)
  Downloading mido-1.3.3-py3-none-any.whl.metadata (6.4 kB)
Downloading mido-1.3.3-py3-none-any.whl (54 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.6/54.6 kB[0m [31m6.8 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: pretty_midi
  Building wheel for pretty_midi (setup.py) ... [?25l[?25hdone
  Created wheel for pretty_midi: filename=pretty_midi-0.2.10-py3-none-any.whl size=5592286 sha256=6b7d68d4fa0805eb4

In [2]:
import os
import kagglehub
import zipfile
import shutil
import numpy as np
import torch
import torch.nn as nn
import os
import numpy as np
import pretty_midi
import torch
from torch.utils.data import Dataset

In [3]:
# Choose the compute device, probing Apple MPS first, then CUDA, and
# falling back to the CPU when neither backend reports itself available.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

In [None]:
# Compatibility shim: if this NumPy build has no `np.int` attribute, alias it
# to the built-in `int`.
# NOTE(review): presumably needed by a dependency that still references `np.int` - confirm.
try:
    np.int
except AttributeError:
    np.int = int

In [None]:
# Define your model class (must match the architecture used in final-project1)
class CNN_LSTM_Classifier(nn.Module):
    """Convolutional front-end, bidirectional LSTM, self-attention and an MLP head.

    Three convolutional stages (32, 64 and 128 channels) each halve both spatial
    axes. The resulting feature map is read along its last axis as a sequence,
    fed through a 2-layer bidirectional LSTM and a multi-head self-attention
    layer, averaged over the sequence and classified by a 4-layer MLP.

    Args:
        num_classes: Size of the output logit vector.
        lstm_hidden: Hidden size per LSTM direction; the attention and classifier
            operate on ``2 * lstm_hidden`` features.
    """

    @staticmethod
    def _conv_stage(in_channels, out_channels, drop_prob):
        """Two 3x3 Conv/BatchNorm/ReLU layers followed by Dropout2d and a 2x2 max-pool."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout2d(drop_prob),
            nn.MaxPool2d(kernel_size=(2, 2)),
        )

    def __init__(self, num_classes=4, lstm_hidden=256):
        super().__init__()

        # Sub-modules are created in the same order and with the same layer
        # layout as before so that state_dict keys (conv1.0, conv1.1, ...) match.
        self.conv1 = self._conv_stage(1, 32, 0.2)
        self.conv2 = self._conv_stage(32, 64, 0.3)
        self.conv3 = self._conv_stage(64, 128, 0.4)

        # 128 channels times 16 rows: the flattened per-step feature width when
        # the input has 128 rows (three 2x2 pools divide the row count by 8).
        self.feature_size = 128 * 16

        self.lstm = nn.LSTM(
            input_size=self.feature_size,
            hidden_size=lstm_hidden,
            num_layers=2,
            batch_first=True,
            dropout=0.3,
            bidirectional=True,
        )

        bidirectional_width = lstm_hidden * 2
        self.attention = nn.MultiheadAttention(
            embed_dim=bidirectional_width,
            num_heads=8,
            dropout=0.3,
            batch_first=True,
        )
        self.classifier = nn.Sequential(
            nn.Linear(bidirectional_width, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for a 4-D input batch."""
        feature_map = self.conv3(self.conv2(self.conv1(x)))

        # Move the last axis to position 1 so it becomes the sequence axis, then
        # merge the channel and row axes into one feature vector per step.
        feature_map = feature_map.permute(0, 3, 1, 2).contiguous()
        sequence = feature_map.view(x.size(0), feature_map.size(1), -1)

        sequence, _ = self.lstm(sequence)
        attended, _ = self.attention(sequence, sequence, sequence)

        # Average over the sequence axis before classification.
        return self.classifier(attended.mean(dim=1))


In [None]:
# # Load the models
# from IPython import get_ipython

# if 'google.colab' in str(get_ipython()):
#     print("Running in Google Colab.")
#     original_path = os.path.join('/content/saved_models', 'original_cnn_lstm.pth')
#     rhythm_path = os.path.join('/content/saved_models', 'rhythm_augmented_cnn_lstm.pth')
# else:
#     print("Not running in Google Colab.")
#     original_path = os.path.join('saved_models', 'original_cnn_lstm.pth')
#     rhythm_path = os.path.join('saved_models', 'rhythm_augmented_cnn_lstm.pth')

# model = CNN_LSTM_Classifier(num_classes=4, lstm_hidden=256).to(device)
# model.load_state_dict(torch.load(original_path, map_location=device))
# print(f"✅ Loaded: {original_path}")

# rhythm_model = CNN_LSTM_Classifier(num_classes=4, lstm_hidden=256).to(device)
# rhythm_model.load_state_dict(torch.load(rhythm_path, map_location=device))
# print(f"✅ Loaded: {rhythm_path}")

Running in Google Colab.
✅ Loaded: /content/saved_models/original_cnn_lstm.pth
✅ Loaded: /content/saved_models/rhythm_augmented_cnn_lstm.pth


In [4]:
# Composers kept in the dataset; every other top-level entry of the archive is deleted below.
TARGET_COMPOSERS = [
    'Bach',
    'Beethoven',
    'Chopin',
    'Mozart',
]

# The path returned by kagglehub is used as the directory that holds the archive.
path = kagglehub.dataset_download("blanderbuss/midi-classic-music")

zip_path = os.path.join(path, 'midiclassics.zip')
extract_path = os.path.join('data', 'kaggle', 'midiclassics')


with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_path)


def _is_target_entry(entry_name):
    """Return True if a top-level entry names a target composer (C.P.E. Bach excluded)."""
    # C.P.E. Bach was the son of J.S. Bach; his entries match the 'Bach'
    # substring but are not wanted, so they are rejected explicitly.
    if 'C.P.E.Bach' in entry_name:
        return False
    lowered = entry_name.lower()
    return any(composer.lower() in lowered for composer in TARGET_COMPOSERS)


# Prune the extracted tree in one pass: anything that is not a target
# composer's file or directory is removed from disk.
for item in os.listdir(extract_path):
    if _is_target_entry(item):
        continue
    item_path = os.path.join(extract_path, item)
    if os.path.isfile(item_path):
        os.remove(item_path)
    elif os.path.isdir(item_path):
        shutil.rmtree(item_path)

In [5]:
class PianoRollDataset(Dataset):
    """Dataset that serves (piano roll, label) pairs as tensors.

    Args:
        data: Array-like of piano rolls; converted once to a float32 tensor.
        labels: Array-like of class indices; converted once to an int64 tensor.
    """

    def __init__(self, data, labels):
        self.data = torch.tensor(data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        """Number of samples (length of the first axis of ``data``)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return sample ``idx`` with a leading channel axis added, plus its label."""
        roll, label = self.data[idx], self.labels[idx]
        # The CNN expects a channel dimension, so (H, T) becomes (1, H, T).
        return roll.unsqueeze(0), label

In [6]:
def get_piano_roll_improved(midi_path, fs=100, target_duration=45.0):
    """
    Convert a MIDI file into one or more fixed-width piano rolls.

    Pieces shorter than 10 seconds are rejected. Pieces longer than
    ``1.5 * target_duration`` yield up to three windows of ``target_duration``
    seconds whose start times are spread evenly over the piece; only notes whose
    onset lies inside a window are kept. All other pieces yield a single roll
    that is cropped around its middle, or zero-padded on the right, to
    ``int(target_duration * fs)`` columns.

    Args:
        midi_path: Path handed to ``pretty_midi.PrettyMIDI``.
        fs: Sampling rate passed to ``get_piano_roll``.
        target_duration: Length in seconds of every returned roll.

    Returns:
        A list of piano-roll arrays (it may be empty for a long piece whose
        windows contain no notes), or ``None`` when the piece is too short or
        any exception occurs (the exception is printed, not raised).
    """
    def _crop_or_pad(roll, width, centered):
        """Force ``roll`` to ``width`` columns by slicing or right zero-padding."""
        surplus = roll.shape[1] - width
        if surplus > 0:
            first = surplus // 2 if centered else 0
            return roll[:, first:first + width]
        return np.pad(roll, ((0, 0), (0, -surplus)), mode='constant')

    try:
        midi = pretty_midi.PrettyMIDI(midi_path)
        total_seconds = midi.get_end_time()

        # Guard: very short pieces are skipped entirely.
        if total_seconds < 10.0:
            return None

        if not (total_seconds > target_duration * 1.5):
            # Normal-length piece: a single roll taken from the middle when it
            # is too long, rather than truncating the end.
            whole_roll = midi.get_piano_roll(fs=fs)
            return [_crop_or_pad(whole_roll, int(target_duration * fs), centered=True)]

        # Long piece: take windows from different parts of it.
        window_count = min(3, int(total_seconds // target_duration))
        rolls = []

        for window_idx in range(window_count):
            window_start = window_idx * (total_seconds / window_count)
            window_end = window_start + target_duration

            # Rebuild a MIDI object holding only this window's notes, shifted so
            # the window starts at time zero and clipped to the window length.
            excerpt = pretty_midi.PrettyMIDI()
            for source in midi.instruments:
                clipped = pretty_midi.Instrument(
                    program=source.program,
                    is_drum=source.is_drum,
                    name=source.name
                )
                clipped.notes.extend(
                    pretty_midi.Note(
                        velocity=note.velocity,
                        pitch=note.pitch,
                        start=note.start - window_start,
                        end=min(note.end - window_start, target_duration)
                    )
                    for note in source.notes
                    if window_start <= note.start < window_end
                )
                if clipped.notes:
                    excerpt.instruments.append(clipped)

            # Windows without any notes contribute nothing.
            if not excerpt.instruments:
                continue

            window_roll = excerpt.get_piano_roll(fs=fs)
            rolls.append(_crop_or_pad(window_roll, int(target_duration * fs), centered=False))

        return rolls

    except Exception as e:
        print(f"Error processing {midi_path}: {e}")
        return None


In [7]:
def normalize_piano_roll(piano_roll):
    """
    Return an independent copy of the piano roll; values are NOT rescaled.

    Despite its name this function applies no scaling. It exists as the hook
    where normalization would go, and callers can rely only on getting a copy
    that is safe to modify. The unused "active pitch range" computation that
    used to live here never influenced the result and has been dropped.

    Args:
        piano_roll: NumPy array holding the piano roll.

    Returns:
        A copy of ``piano_roll`` with identical shape, dtype and values.
    """
    # NOTE(review): pretty_midi piano rolls appear to hold note velocities
    # (summed across instruments), not values in [0, 1] - confirm before
    # assuming a 0-1 range downstream.
    return piano_roll.copy()


In [8]:
def extract_musical_features(piano_roll):
    """
    Summarise a piano roll with simple density, pitch and rhythm statistics.

    Args:
        piano_roll: 2-D array indexed as (pitch row, time frame); a cell greater
            than zero counts as a sounding note.

    Returns:
        dict with the keys
        ``avg_notes_per_time``   - mean number of sounding pitches per frame,
        ``note_density_variance`` - variance of that per-frame count,
        ``pitch_range``          - number of distinct pitch rows that ever sound,
        ``lowest_pitch`` / ``highest_pitch`` - first / last sounding row index
            (both 60 when nothing sounds),
        ``onset_density``        - fraction of frame-to-frame steps that go from
            silence to sound (0 when the roll has fewer than two frames).
    """
    features = {}
    sounding = piano_roll > 0

    # Temporal features
    note_density_timeline = np.sum(sounding, axis=0)
    features['avg_notes_per_time'] = np.mean(note_density_timeline)
    features['note_density_variance'] = np.var(note_density_timeline)

    # Pitch features
    active_pitches = np.any(sounding, axis=1)
    if np.any(active_pitches):
        features['pitch_range'] = np.sum(active_pitches)
        features['lowest_pitch'] = np.argmax(active_pitches)
        # Index of the last active row, derived from the actual row count
        # instead of assuming exactly 128 rows.
        features['highest_pitch'] = len(active_pitches) - 1 - np.argmax(active_pitches[::-1])
    else:
        features['pitch_range'] = 0
        features['lowest_pitch'] = 60  # Middle C
        features['highest_pitch'] = 60

    # Rhythmic features
    # Cast to int BEFORE differencing: np.diff on a boolean array yields a
    # logical "changed" flag, which would count note offsets as onsets too.
    transitions = np.diff((note_density_timeline > 0).astype(int))
    if len(transitions) > 0:
        features['onset_density'] = np.sum(transitions == 1) / len(transitions)
    else:
        # A roll with fewer than two frames has no transitions to count.
        features['onset_density'] = 0

    return features

# Status banner printed once the processing helpers above are defined.
print(
    "✅ Improved data processing functions defined!",
    "Key improvements:",
    "• Intelligent segment extraction for long pieces",
    "• Musical boundary awareness",
    "• Better normalization",
    "• Feature extraction for analysis",
    sep="\n",
)

✅ Improved data processing functions defined!
Key improvements:
• Intelligent segment extraction for long pieces
• Musical boundary awareness
• Better normalization
• Feature extraction for analysis


In [None]:
# =====================================================
# IMPROVED DATA LOADING WITH BETTER PROCESSING
# =====================================================

import numpy as np

def load_improved_dataset(extract_path, target_composers, target_duration=45.0, max_files_per_composer=None):
    """
    Build a labelled piano-roll dataset from per-composer folders of MIDI files.

    For every composer, the ``.mid`` / ``.midi`` files directly inside
    ``extract_path/<composer>`` are converted with ``get_piano_roll_improved``,
    passed through ``normalize_piano_roll`` and summarised with
    ``extract_musical_features``. Segments whose average number of sounding
    notes per frame is below 0.1 are discarded.

    Files are processed in sorted name order so that ``max_files_per_composer``
    always selects the same subset; ``os.listdir`` alone gives no ordering
    guarantee.

    Args:
        extract_path: Directory containing one sub-directory per composer.
        target_composers: Composer names; a composer's position in this
            sequence is its integer label.
        target_duration: Segment length in seconds, forwarded to
            ``get_piano_roll_improved``.
        max_files_per_composer: Optional cap on files read per composer
            (``None`` or 0 means no cap).

    Returns:
        ``(data, labels, features)`` where ``data`` and ``labels`` are NumPy
        arrays and ``features`` is a list with one feature dict per sample, or
        ``(None, None, None)`` when nothing could be loaded.
    """
    print("🎵 LOADING DATASET WITH IMPROVED PROCESSING...")
    print("Improvements over original:")
    print("• Intelligent segment extraction for long pieces")
    print("• Better handling of piece lengths")
    print("• Musical feature extraction")
    print("• Quality filtering")

    composer_to_idx = {c: i for i, c in enumerate(target_composers)}
    all_data = []
    all_labels = []
    all_features = []

    for composer in target_composers:
        print(f"\n--- Processing {composer} ---")
        composer_dir = os.path.join(extract_path, composer)

        if not os.path.isdir(composer_dir):
            print(f"Directory not found: {composer_dir}")
            continue

        composer_data = []
        composer_labels = []
        composer_features = []
        files_processed = 0
        segments_created = 0

        # Sorted so the selection below is reproducible across runs and machines.
        midi_files = sorted(
            f for f in os.listdir(composer_dir)
            if f.lower().endswith(('.mid', '.midi'))
        )

        if max_files_per_composer:
            midi_files = midi_files[:max_files_per_composer]

        for file in midi_files:
            midi_path = os.path.join(composer_dir, file)

            try:
                segments = get_piano_roll_improved(midi_path, target_duration=target_duration)

                # None means the file was too short or could not be parsed.
                if segments is None:
                    continue

                for segment in segments:
                    normalized_segment = normalize_piano_roll(segment)
                    features = extract_musical_features(normalized_segment)

                    # Quality check: skip segments that are almost silent.
                    if features['avg_notes_per_time'] < 0.1:
                        continue

                    composer_data.append(normalized_segment)
                    composer_labels.append(composer_to_idx[composer])
                    composer_features.append(features)
                    segments_created += 1

                files_processed += 1

                if files_processed % 10 == 0:
                    print(f"  Processed {files_processed} files, created {segments_created} segments...")

            except Exception as e:
                # Best effort: one bad file must not abort the whole load.
                print(f"  Error processing {file}: {e}")
                continue

        print(f"✅ {composer}: {files_processed} files → {segments_created} segments")

        if composer_data:
            composer_data = np.array(composer_data)
            composer_labels = np.array(composer_labels)

            all_data.append(composer_data)
            all_labels.append(composer_labels)
            all_features.extend(composer_features)

            print(f"  Final data shape: {composer_data.shape}")

    if not all_data:
        print("❌ No data loaded!")
        return None, None, None

    # Combine the per-composer arrays into one dataset.
    data = np.concatenate(all_data, axis=0)
    labels = np.concatenate(all_labels, axis=0)

    print(f"\n🎯 FINAL IMPROVED DATASET:")
    print(f"Total samples: {len(data)}")
    print(f"Data shape: {data.shape}")
    print(f"Label distribution: {np.bincount(labels)}")

    return data, labels, all_features

# Build the improved dataset from the extracted files.
print("🚀 Starting improved data loading...")
improved_data, improved_labels, features = load_improved_dataset(
    extract_path=extract_path,
    target_composers=TARGET_COMPOSERS,
    target_duration=45.0,        # seconds per segment
    max_files_per_composer=120,  # cap used for testing; pass None to read every file
)

In [None]:
# =====================================================
# IMPROVED NON-OVERLAPPING SEGMENTATION
# =====================================================

from collections import defaultdict

def balance_classes_song_aware(segments_by_composer, target_samples_per_class=None):
    """
    Equalise the number of segments per composer while keeping songs together.

    Composers with at least the target number of segments are downsampled by
    picking whole songs in random order (the last song may be cut short, taking
    its leading segments). Composers with fewer segments are oversampled by
    duplicating randomly chosen songs; all segments copied from one duplicated
    song share a single new ``song_id`` of the form ``<song_id>_dup_<n>`` so
    that they can still be grouped as one song. A composer with no segments at
    all cannot be oversampled and is returned empty with a printed notice.

    Randomness comes from the global ``np.random`` state; seed it beforehand
    for reproducible results.

    Args:
        segments_by_composer: Mapping of composer name to a list of segment
            dicts, each carrying at least a ``'song_id'`` key.
        target_samples_per_class: Desired segments per composer; defaults to
            the size of the largest class.

    Returns:
        A new mapping of composer name to the balanced list of segment dicts.
        Duplicated segments are shallow copies; the input dicts are not modified.
    """
    # Calculate current distribution
    composer_counts = {composer: len(segments) for composer, segments in segments_by_composer.items()}
    if target_samples_per_class is None:
        # default=0 keeps an empty mapping from raising inside max().
        max_count = max(composer_counts.values(), default=0)
    else:
        max_count = target_samples_per_class

    print(f"📊 Original segment distribution: {composer_counts}")
    print(f"🎯 Target segments per class: {max_count}")

    balanced_segments = {}

    for composer, segments in segments_by_composer.items():
        current_count = len(segments)

        # Group this composer's segments by the song they came from.
        songs_dict = defaultdict(list)
        for segment in segments:
            songs_dict[segment['song_id']].append(segment)
        song_ids = list(songs_dict.keys())

        if current_count >= max_count:
            # Enough material: pick songs in random order until the target is hit.
            selected_segments = []
            np.random.shuffle(song_ids)

            for song_id in song_ids:
                song_segments = songs_dict[song_id]
                if len(selected_segments) + len(song_segments) <= max_count:
                    selected_segments.extend(song_segments)
                elif len(selected_segments) < max_count:
                    # Partial song - take contiguous segments from the beginning
                    needed = max_count - len(selected_segments)
                    selected_segments.extend(song_segments[:needed])
                    break

            balanced_segments[composer] = selected_segments
            print(f"  {composer}: {current_count} → {len(selected_segments)} (downsampled)")

        elif not song_ids:
            # Nothing to duplicate; np.random.choice would raise on an empty list.
            balanced_segments[composer] = []
            print(f"  {composer}: 0 → 0 (no segments available, cannot oversample)")

        else:
            # Too little material: start from every original segment and
            # append copies of randomly chosen songs until the target is hit.
            selected_segments = segments.copy()

            while len(selected_segments) < max_count:
                song_segments = songs_dict[np.random.choice(song_ids)]
                remaining = max_count - len(selected_segments)

                # One id per duplicated song, fixed before any copy is appended,
                # so every copied segment of this song carries the same id.
                dup_song_id = f"{song_segments[0]['song_id']}_dup_{len(selected_segments)}"

                # Slicing covers both cases: the whole song when it fits, or
                # only its leading segments when it would overshoot the target.
                for segment in song_segments[:remaining]:
                    new_segment = segment.copy()
                    new_segment['song_id'] = dup_song_id
                    selected_segments.append(new_segment)

            balanced_segments[composer] = selected_segments
            print(f"  {composer}: {current_count} → {len(selected_segments)} (+{len(selected_segments) - current_count} from song duplication)")

    return balanced_segments

def get_piano_roll_segments_no_overlap(midi_path, fs=100, segment_duration=20.0):
    """
    Cut a MIDI file into consecutive, non-overlapping piano-roll segments.

    Windows of ``segment_duration`` seconds are taken back to back from time
    zero for as long as a full window fits inside the piece. A note belongs to
    the window in which it starts and is clipped at the window end. Windows
    that contain no notes are skipped and do not consume a segment index.

    Args:
        midi_path: Path handed to ``pretty_midi.PrettyMIDI``; also stored as
            each segment's ``'song_id'``.
        fs: Sampling rate passed to ``get_piano_roll``.
        segment_duration: Window length in seconds.

    Returns:
        A list of dicts with keys ``'piano_roll'`` (array with exactly
        ``int(segment_duration * fs)`` columns), ``'song_id'``,
        ``'segment_idx'`` and ``'start_time'``; or ``None`` when the piece is
        shorter than one window, no window contains notes, or an exception
        occurs (it is printed, not raised).
    """
    try:
        midi = pretty_midi.PrettyMIDI(midi_path)
        total_seconds = midi.get_end_time()

        # Pieces shorter than a single window yield nothing.
        if total_seconds < segment_duration:
            return None

        frames_per_segment = int(segment_duration * fs)
        collected = []
        window_start = 0.0

        while window_start + segment_duration <= total_seconds:
            window_end = window_start + segment_duration

            # Rebuild a MIDI object with only this window's notes, shifted so
            # that the window begins at time zero.
            excerpt = pretty_midi.PrettyMIDI()
            for source in midi.instruments:
                clipped = pretty_midi.Instrument(
                    program=source.program,
                    is_drum=source.is_drum,
                    name=source.name
                )
                clipped.notes.extend(
                    pretty_midi.Note(
                        velocity=note.velocity,
                        pitch=note.pitch,
                        start=note.start - window_start,
                        end=min(note.end - window_start, segment_duration)
                    )
                    for note in source.notes
                    if window_start <= note.start < window_end
                )
                if clipped.notes:
                    excerpt.instruments.append(clipped)

            if excerpt.instruments:
                roll = excerpt.get_piano_roll(fs=fs)

                # Trim or zero-pad on the right to the exact segment width.
                missing = frames_per_segment - roll.shape[1]
                if missing < 0:
                    roll = roll[:, :frames_per_segment]
                elif missing > 0:
                    roll = np.pad(roll, ((0, 0), (0, missing)), mode='constant')

                # The index counts kept segments only, hence len(collected).
                collected.append({
                    'piano_roll': roll,
                    'song_id': midi_path,
                    'segment_idx': len(collected),
                    'start_time': window_start
                })

            # Advance by a full window: consecutive segments never overlap.
            window_start += segment_duration

        return collected or None

    except Exception as e:
        print(f"Error processing {midi_path}: {e}")
        return None

def load_segmented_dataset_no_overlap(extract_path, target_composers, segment_duration=20.0, 
                                     max_files_per_composer=None, balance_classes=True):
    """
    Load a dataset of non-overlapping piano-roll segments, optionally balanced.

    For every composer, the ``.mid`` / ``.midi`` files directly inside
    ``extract_path/<composer>`` are cut with
    ``get_piano_roll_segments_no_overlap``; each segment is passed through
    ``normalize_piano_roll`` and ``extract_musical_features`` and dropped when
    its average number of sounding notes per frame is below 0.05. When
    ``balance_classes`` is true the per-composer lists are equalised with
    ``balance_classes_song_aware``.

    Files are processed in sorted name order so that ``max_files_per_composer``
    always selects the same subset; ``os.listdir`` alone gives no ordering
    guarantee.

    Args:
        extract_path: Directory containing one sub-directory per composer.
        target_composers: Composer names; a composer's position in this
            sequence is its integer label.
        segment_duration: Segment length in seconds.
        max_files_per_composer: Optional cap on files read per composer
            (``None`` or 0 means no cap).
        balance_classes: Whether to balance the classes after loading.

    Returns:
        ``(data, labels, features)``: a NumPy array of piano rolls, an int64
        NumPy array of labels, and a list with one feature dict per sample.
        All three are empty when nothing could be loaded.
    """
    print("🎵 LOADING DATASET WITH NON-OVERLAPPING SEGMENTATION...")
    print(f"Segment duration: {segment_duration}s with NO OVERLAP")
    print(f"Balance classes: {balance_classes}")

    composer_to_idx = {c: i for i, c in enumerate(target_composers)}
    segments_by_composer = {composer: [] for composer in target_composers}

    for composer in target_composers:
        print(f"\n--- Processing {composer} ---")
        composer_dir = os.path.join(extract_path, composer)

        if not os.path.isdir(composer_dir):
            print(f"Directory not found: {composer_dir}")
            continue

        files_processed = 0
        total_segments = 0

        # Sorted so the selection below is reproducible across runs and machines.
        midi_files = sorted(
            f for f in os.listdir(composer_dir)
            if f.lower().endswith(('.mid', '.midi'))
        )

        if max_files_per_composer:
            midi_files = midi_files[:max_files_per_composer]

        for file in midi_files:
            midi_path = os.path.join(composer_dir, file)

            try:
                segments = get_piano_roll_segments_no_overlap(
                    midi_path,
                    segment_duration=segment_duration
                )

                # None means the file was too short, empty or unreadable.
                if segments is None:
                    continue

                valid_segments = []
                for segment_info in segments:
                    normalized_segment = normalize_piano_roll(segment_info['piano_roll'])
                    features = extract_musical_features(normalized_segment)

                    # Quality check: skip segments that are almost silent.
                    if features['avg_notes_per_time'] < 0.05:
                        continue

                    # Enrich the segment record in place with the processed data.
                    segment_info['piano_roll'] = normalized_segment
                    segment_info['features'] = features
                    segment_info['label'] = composer_to_idx[composer]

                    valid_segments.append(segment_info)

                segments_by_composer[composer].extend(valid_segments)
                total_segments += len(valid_segments)
                files_processed += 1

                if files_processed % 10 == 0:
                    print(f"  Processed {files_processed} files, created {total_segments} segments...")

            except Exception as e:
                # Best effort: one bad file must not abort the whole load.
                print(f"  Error processing {file}: {e}")
                continue

        print(f"✅ {composer}: {files_processed} files → {total_segments} segments")

    # Balance classes if requested
    if balance_classes:
        print(f"\n⚖️ BALANCING CLASSES (SONG-AWARE, NO OVERLAP)...")
        segments_by_composer = balance_classes_song_aware(segments_by_composer)

    # Flatten the per-composer lists into parallel sequences.
    all_data = []
    all_labels = []
    all_features = []

    for composer, segments in segments_by_composer.items():
        for segment_info in segments:
            all_data.append(segment_info['piano_roll'])
            all_labels.append(segment_info['label'])
            all_features.append(segment_info['features'])

    data = np.array(all_data)
    # Explicit integer dtype: an empty list would otherwise become a float64
    # array, which np.bincount refuses to accept.
    labels = np.array(all_labels, dtype=np.int64)

    print(f"\n📊 FINAL BALANCED DATASET (NO OVERLAP):")
    print(f"Total samples: {len(data)}")
    print(f"Data shape: {data.shape}")
    # minlength guarantees one count per composer, including composers with zero samples.
    final_counts = np.bincount(labels, minlength=len(target_composers))
    for i, composer in enumerate(target_composers):
        print(f"  {composer}: {final_counts[i]} samples")

    return data, labels, all_features

# Build the non-overlapping, class-balanced segment dataset.
print("🚀 Starting NON-OVERLAPPING segmented data loading...")
segmented_data, segmented_labels, segmented_features = load_segmented_dataset_no_overlap(
    extract_path=extract_path,
    target_composers=TARGET_COMPOSERS,
    segment_duration=20.0,       # seconds per segment
    max_files_per_composer=150,  # cap on files read per composer
    balance_classes=True,        # song-aware class balancing
)

In [None]:
# =====================================================
# ENHANCED MUSICAL FEATURE EXTRACTION FOR MULTIMODAL
# =====================================================

def extract_comprehensive_musical_features(piano_roll):
    """
    Extract comprehensive musical features for the MLP stream

    Summarises one piano roll as a flat dictionary of 30 scalar descriptors
    (temporal/rhythmic, pitch/register, harmonic, velocity and style). Keys
    are always inserted in the same order, so ``list(features.keys())`` is a
    stable feature layout.

    Args:
        piano_roll: 2-D array-like indexed as ``[pitch, time_frame]``. A cell
            counts as "sounding" when its value is > 0, and the positive
            values are used as velocities. The register bands below index the
            pitch axis with MIDI note numbers (21-107).

    Returns:
        dict: feature name -> scalar (NumPy scalar or plain Python number).

    Notes:
        * ``onset_density`` counts only silent -> sounding transitions. The
          activity mask is cast to int *before* ``np.diff``; differencing a
          boolean array marks every change (onsets and offsets alike).
          With this definition it equals ``activity_bursts`` whenever the
          roll contains at least one sounding frame.
        * The number of pitch rows is taken from ``piano_roll.shape[0]``
          rather than assumed to be 128.
        * A roll with zero time frames is analysed as a single silent frame,
          so it yields the "empty" defaults instead of raising.
        * ``pitch_variance`` is the variance of the per-pitch activity counts
          over the pitches that occur, not the variance of pitch numbers.
    """
    piano_roll = np.asarray(piano_roll)
    n_pitches, n_frames = piano_roll.shape
    if n_frames == 0:
        # Nothing to measure: substitute one silent frame so the reductions
        # below (max, ratios) are well defined.
        piano_roll = np.zeros((n_pitches, 1), dtype=piano_roll.dtype)

    features = {}
    pitch_index = np.arange(n_pitches)

    # Get basic timeline and pitch data
    sounding = piano_roll > 0
    note_density_timeline = np.sum(sounding, axis=0)  # sounding notes per frame
    pitch_activity = np.sum(sounding, axis=1)         # sounding frames per pitch
    active_pitches = pitch_activity > 0
    active_frames = note_density_timeline > 0
    timeline_length = len(note_density_timeline)

    # ==========================================
    # TEMPORAL/RHYTHMIC FEATURES
    # ==========================================

    # Basic temporal statistics
    features['avg_notes_per_time'] = np.mean(note_density_timeline)
    features['note_density_variance'] = np.var(note_density_timeline)
    features['note_density_std'] = np.std(note_density_timeline)
    features['max_simultaneous_notes'] = np.max(note_density_timeline)

    # Rhythmic complexity: +1 marks silent -> sounding, -1 sounding -> silent
    activity_steps = np.diff(active_frames.astype(int))
    if len(activity_steps) > 0:
        onset_rate = np.sum(activity_steps == 1) / len(activity_steps)
    else:
        onset_rate = 0
    features['onset_density'] = onset_rate
    features['silence_ratio'] = np.sum(note_density_timeline == 0) / timeline_length

    # Temporal distribution analysis
    if np.any(active_frames):
        features['temporal_sparsity'] = 1 - (np.sum(active_frames) / timeline_length)
        # Bursts of activity = starts of sounding runs
        features['activity_bursts'] = onset_rate
    else:
        features['temporal_sparsity'] = 1.0
        features['activity_bursts'] = 0.0

    # ==========================================
    # PITCH/HARMONIC FEATURES
    # ==========================================

    if np.any(active_pitches):
        # Basic pitch statistics
        features['pitch_range'] = np.sum(active_pitches)  # number of distinct pitches used
        features['lowest_pitch'] = np.argmax(active_pitches)
        features['highest_pitch'] = (n_pitches - 1) - np.argmax(active_pitches[::-1])
        features['pitch_span'] = features['highest_pitch'] - features['lowest_pitch']

        # Pitch distribution. total_activity is > 0 here because at least one
        # pitch is active, so no zero-division guard is needed.
        total_activity = np.sum(pitch_activity)
        features['pitch_centroid'] = np.sum(pitch_index * pitch_activity) / total_activity
        features['pitch_variance'] = np.var(pitch_activity[active_pitches])

        # Register analysis (musical ranges). Pitches outside 21-107 count
        # towards total_activity but fall in no band, so the three ratios
        # need not sum to 1.
        features['bass_activity'] = np.sum(pitch_activity[21:48]) / total_activity     # A0 to B2
        features['mid_activity'] = np.sum(pitch_activity[48:72]) / total_activity      # C3 to B4
        features['treble_activity'] = np.sum(pitch_activity[72:108]) / total_activity  # C5 to B7
    else:
        # Default values for empty piano rolls
        features.update({
            'pitch_range': 0, 'lowest_pitch': 60, 'highest_pitch': 60,
            'pitch_span': 0, 'pitch_centroid': 60, 'pitch_variance': 0,
            'bass_activity': 0, 'mid_activity': 0, 'treble_activity': 0
        })

    # ==========================================
    # HARMONIC COMPLEXITY FEATURES
    # ==========================================

    # Analyze simultaneous note patterns (chords vs single notes)
    chord_frames = note_density_timeline >= 3  # 3+ simultaneous notes = chord
    single_note_frames = note_density_timeline == 1

    features['chord_ratio'] = np.sum(chord_frames) / timeline_length
    features['single_note_ratio'] = np.sum(single_note_frames) / timeline_length
    features['polyphony_complexity'] = (
        np.mean(note_density_timeline[active_frames]) if np.any(active_frames) else 0
    )

    # Chord complexity analysis
    if np.sum(chord_frames) > 0:
        chord_complexities = note_density_timeline[chord_frames]
        features['avg_chord_size'] = np.mean(chord_complexities)
        features['chord_variance'] = np.var(chord_complexities)
    else:
        features['avg_chord_size'] = 0
        features['chord_variance'] = 0

    # ==========================================
    # VELOCITY/DYNAMICS FEATURES
    # ==========================================

    if np.any(sounding):
        velocities = piano_roll[sounding]
        features['avg_velocity'] = np.mean(velocities)
        features['velocity_variance'] = np.var(velocities)
        features['velocity_range'] = np.max(velocities) - np.min(velocities)
        features['dynamic_complexity'] = len(np.unique(velocities)) / len(velocities)
    else:
        features.update({
            'avg_velocity': 0, 'velocity_variance': 0,
            'velocity_range': 0, 'dynamic_complexity': 0
        })

    # ==========================================
    # STYLE-SPECIFIC FEATURES
    # ==========================================

    # Measure musical "busyness" vs "spaciousness"
    features['overall_density'] = np.sum(sounding) / piano_roll.size

    # Temporal consistency (how regular/irregular the rhythm is)
    if timeline_length > 1:
        features['rhythmic_regularity'] = 1 / (1 + np.var(note_density_timeline))
    else:
        features['rhythmic_regularity'] = 1.0

    # Pitch movement patterns: velocity-weighted mean pitch of every sounding
    # frame, then statistics of its frame-to-frame change.
    pitch_movement_variance = 0
    melodic_direction_changes = 0
    if np.sum(active_pitches) > 1:
        sounding_columns = piano_roll[:, active_frames]
        pitch_centers = (
            (pitch_index[:, None] * sounding_columns).sum(axis=0)
            / sounding_columns.sum(axis=0)
        )
        if len(pitch_centers) > 1:
            pitch_movement = np.diff(pitch_centers)
            pitch_movement_variance = np.var(pitch_movement)
            # Fraction of steps where the direction (up / flat / down) changes
            melodic_direction_changes = (
                np.sum(np.diff(np.sign(pitch_movement)) != 0) / len(pitch_movement)
            )
    features['pitch_movement_variance'] = pitch_movement_variance
    features['melodic_direction_changes'] = melodic_direction_changes

    return features

# Smoke-test the extractor on the first loaded sample, when one is available
print("🧪 Testing enhanced musical feature extraction...")
if 'improved_data' not in globals() or improved_data is None:
    print("⏳ Run data loading first to test feature extraction")
else:
    sample_piano_roll = improved_data[0]
    enhanced_features = extract_comprehensive_musical_features(sample_piano_roll)

    # One print call, one line per argument
    print(
        "✅ Enhanced features extracted!",
        f"Number of features: {len(enhanced_features)}",
        f"Feature names: {list(enhanced_features.keys())}",
        f"Sample values: {list(enhanced_features.values())[:5]}...",
        sep="\n",
    )

# Summary of the feature families produced by the extractor
print(
    "\n🎯 ENHANCED FEATURES READY FOR MULTIMODAL!",
    "Features include:",
    "• Temporal/Rhythmic: density, sparsity, bursts",
    "• Pitch/Harmonic: range, centroid, register distribution",
    "• Harmonic Complexity: chord ratios, polyphony",
    "• Dynamics: velocity patterns, expression",
    "• Style: movement patterns, regularity",
    sep="\n",
)

In [None]:
# =====================================================
# MULTIMODAL DATASET CLASS
# =====================================================

class MultimodalDataset(Dataset):
    """
    Dataset class that handles both piano rolls (for CNN) and musical features (for MLP)

    Args:
        piano_rolls: array-like of piano rolls, one per sample; converted to a
            float32 tensor.
        features_list: list with one feature dictionary per sample. The keys
            of the first dictionary define the feature order; every other
            dictionary must contain those keys.
        labels: integer class label per sample; converted to a long tensor.

    Raises:
        ValueError: if the three inputs do not have the same length or are
            empty.

    Attributes set on construction:
        piano_rolls, features, labels: the tensors served by ``__getitem__``.
        feature_names: feature order used for the columns of ``features``.
        feature_mean, feature_std: per-feature statistics used for the
            z-score normalisation, kept so that other data can be normalised
            the same way.

    Note:
        The normalisation statistics are computed over *all* samples passed
        in, so build the dataset from the data you intend to fit statistics on.
    """
    def __init__(self, piano_rolls, features_list, labels):
        n_samples = len(labels)
        if len(piano_rolls) != n_samples or len(features_list) != n_samples:
            raise ValueError(
                "piano_rolls, features_list and labels must have the same length "
                f"(got {len(piano_rolls)}, {len(features_list)}, {n_samples})"
            )
        if n_samples == 0:
            raise ValueError("Cannot build a MultimodalDataset from zero samples")

        # Convert piano rolls to tensor
        self.piano_rolls = torch.tensor(piano_rolls, dtype=torch.float32)

        # Convert feature dictionaries to feature vectors
        self.features = self._process_features(features_list)

        # Convert labels to tensor
        self.labels = torch.tensor(labels, dtype=torch.long)

        print(f"✅ Multimodal Dataset Created:")
        print(f"• Piano rolls: {self.piano_rolls.shape}")
        print(f"• Features: {self.features.shape}")
        print(f"• Labels: {self.labels.shape}")
        print(f"• Total samples: {len(self.labels)}")

    def _process_features(self, features_list):
        """Convert list of feature dictionaries to a z-score normalised tensor.

        Also records ``feature_names``, ``feature_mean`` and ``feature_std``
        on the instance.
        """
        if len(features_list) == 0:
            raise ValueError("features_list is empty; cannot build a feature matrix")

        # Get feature names from first sample
        feature_names = list(features_list[0].keys())

        # Extract feature values for all samples, in that fixed order
        feature_matrix = [
            [feature_dict[name] for name in feature_names]
            for feature_dict in features_list
        ]

        # Convert to tensor and normalize
        features_tensor = torch.tensor(feature_matrix, dtype=torch.float32)

        # Normalize features (important for MLP training)
        features_mean = features_tensor.mean(dim=0)
        if features_tensor.shape[0] > 1:
            features_std = features_tensor.std(dim=0)
        else:
            # The sample standard deviation of a single row is undefined
            # (NaN); fall back to unit scale so the features stay finite.
            features_std = torch.ones_like(features_mean)
        features_std[features_std == 0] = 1  # Avoid division by zero
        features_normalized = (features_tensor - features_mean) / features_std

        # Keep the statistics so unseen data can be normalised identically
        self.feature_names = feature_names
        self.feature_mean = features_mean
        self.feature_std = features_std

        print(f"📊 Feature Processing:")
        print(f"• Feature names: {feature_names[:5]}...")
        print(f"• Features normalized: mean≈0, std≈1")

        return features_normalized

    def __len__(self):
        """Number of samples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(piano_roll, features, label)`` for sample ``idx``."""
        # Piano roll with channel dimension for CNN: (1, 128, T)
        piano_roll = self.piano_rolls[idx].unsqueeze(0)

        # Feature vector for MLP: (num_features,)
        features = self.features[idx]

        # Label
        label = self.labels[idx]

        return piano_roll, features, label

# Build the multimodal dataset: one comprehensive feature dict per loaded piano roll
print("🎵 CREATING MULTIMODAL DATASET...")
print("Extracting comprehensive features for all samples...")

if improved_data is None:
    print("❌ No data loaded - run data loading first!")
else:
    print("Extracting comprehensive features for multimodal training...")
    comprehensive_features = []

    for i, piano_roll in enumerate(improved_data):
        features = extract_comprehensive_musical_features(piano_roll)
        comprehensive_features.append(features)

        # Progress line after every 100th sample
        if i % 100 == 99:
            print(f"  Processed {i + 1}/{len(improved_data)} samples...")

    # Wrap piano rolls, feature dicts and labels in the dataset class
    multimodal_dataset = MultimodalDataset(
        improved_data,
        comprehensive_features,
        improved_labels,
    )

    print(
        "\n🎯 MULTIMODAL DATASET READY!",
        "Ready for dual-stream architecture training!",
        sep="\n",
    )

In [None]:
# =====================================================
# CREATE MULTIMODAL DATASET FROM SEGMENTED DATA
# =====================================================

# Build the multimodal dataset from the segmented data: one feature dict per segment
print("🎵 CREATING MULTIMODAL DATASET FROM SEGMENTED DATA...")

if segmented_data is None:
    print("❌ No segmented data loaded - run segmentation first!")
else:
    print("Extracting comprehensive features for multimodal training...")
    comprehensive_features = []

    for i, piano_roll in enumerate(segmented_data):
        features = extract_comprehensive_musical_features(piano_roll)
        comprehensive_features.append(features)

        # Progress line after every 100th segment
        if i % 100 == 99:
            print(f"  Processed {i + 1}/{len(segmented_data)} samples...")

    # Wrap segments, feature dicts and labels in the dataset class
    multimodal_dataset = MultimodalDataset(
        segmented_data,
        comprehensive_features,
        segmented_labels,
    )

    print(
        "\n🎯 BALANCED MULTIMODAL DATASET READY!",
        "Using segmented data with temporal consistency preserved!",
        sep="\n",
    )

In [None]:
# =====================================================
# MULTIMODAL CNN-MLP FUSION ARCHITECTURE
# =====================================================

class MultimodalComposerClassifier(nn.Module):
    """
    Multimodal architecture combining CNN (piano rolls) + MLP (musical features)

    Two streams are computed independently and concatenated:

    * ``cnn_stream``: three conv blocks (32, 64, 128 channels) followed by
      global average pooling, giving a 128-dim vector per sample.
    * ``mlp_stream``: three linear layers reducing ``num_features`` to 32.
    * ``fusion``: MLP head on the 160-dim concatenation, ending in
      ``num_classes`` logits.

    Args:
        num_classes: number of output logits.
        num_features: length of the hand-crafted feature vector. Required;
            the ``None`` default only exists for signature compatibility.

    Raises:
        TypeError: if ``num_features`` is not supplied.
    """
    def __init__(self, num_classes=4, num_features=None):
        super(MultimodalComposerClassifier, self).__init__()

        if num_features is None:
            # Fail early with a readable message instead of an opaque error
            # from inside nn.Linear.
            raise TypeError(
                "MultimodalComposerClassifier requires num_features "
                "(the length of the musical feature vector)"
            )

        # CNN Stream for Piano Rolls (visual patterns).
        # The layers are kept in ONE flat Sequential, created in a fixed
        # order, so parameter names (cnn_stream.<idx>.*) stay stable.
        self.cnn_stream = nn.Sequential(
            *self._conv_block(1, 32, dropout=0.2),     # Block 1 -> indices 0-7
            *self._conv_block(32, 64, dropout=0.3),    # Block 2 -> indices 8-15
            *self._conv_block(64, 128, dropout=0.4),   # Block 3 -> indices 16-23

            # Global pooling instead of LSTM
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten()
        )

        # MLP Stream for Musical Features (hand-crafted features)
        self.mlp_stream = nn.Sequential(
            nn.Linear(num_features, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Dropout(0.1)
        )

        # Fusion Layer (combine both streams)
        self.fusion = nn.Sequential(
            nn.Linear(128 + 32, 256),  # 128 from CNN + 32 from MLP
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, num_classes)
        )

    @staticmethod
    def _conv_block(in_channels, out_channels, dropout):
        """Return the 8 layers of one block: 2x(Conv3x3-BN-ReLU), Dropout2d, MaxPool 2x2."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout2d(dropout),
            nn.MaxPool2d(kernel_size=(2, 2)),
        ]

    def forward(self, piano_roll, features):
        """
        Args:
            piano_roll: tensor of shape (batch, 1, pitches, time) for the CNN.
            features: tensor of shape (batch, num_features) for the MLP.

        Returns:
            Tensor of shape (batch, num_classes) with unnormalised logits.
        """
        # CNN stream processes piano rolls
        cnn_features = self.cnn_stream(piano_roll)

        # MLP stream processes musical features
        mlp_features = self.mlp_stream(features)

        # Concatenate both feature streams
        combined = torch.cat([cnn_features, mlp_features], dim=1)

        # Final classification
        output = self.fusion(combined)

        return output

# Console summary of the architecture defined above
print(
    "🏗️ Multimodal Architecture Defined!",
    "• CNN Stream: Processes piano roll visual patterns",
    "• MLP Stream: Processes hand-crafted musical features",
    "• Fusion Layer: Combines both streams for classification",
    "• No LSTM: Uses global pooling for better efficiency",
    sep="\n",
)

In [None]:
# =====================================================
# TRAINING SETUP & EXECUTION
# =====================================================

from torch.utils.data import DataLoader, random_split
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt

# Create model
num_features = len(comprehensive_features[0])  # Number of musical features
model = MultimodalComposerClassifier(
    num_classes=4, 
    num_features=num_features
).to(device)

print(f"Model created with {num_features} musical features")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

# Train/validation split (80/20).
# The split is seeded so the same samples land in the validation set on every
# run; without it, validation accuracy is not comparable between runs.
# NOTE(review): the split is per sample, so segments cut from the same song can
# end up on both sides, and the dataset's feature normalisation was computed
# before splitting. Both can make validation accuracy optimistic.
SPLIT_SEED = 42
train_size = int(0.8 * len(multimodal_dataset))
val_size = len(multimodal_dataset) - train_size
train_dataset, val_dataset = random_split(
    multimodal_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Training setup
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
# Multiply the learning rate by 0.7 every 10 epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.7)

# Training function
def train_multimodal_model(model, train_loader, val_loader, epochs=25,
                           criterion=None, optimizer=None, scheduler=None, device=None):
    """
    Train ``model`` and report validation accuracy after every epoch.

    Args:
        model: module called as ``model(piano_rolls, features)`` returning
            class logits.
        train_loader: iterable of ``(piano_rolls, features, labels)`` batches
            used for optimisation.
        val_loader: iterable of batches in the same format, used only for
            accuracy measurement.
        epochs: number of passes over ``train_loader``.
        criterion, optimizer, scheduler, device: training components. Each
            one left as ``None`` falls back to the notebook-level variable of
            the same name, which keeps existing four-argument calls working.

    Returns:
        (train_losses, val_accuracies): two lists with one entry per epoch;
        the mean batch loss and the validation accuracy in percent.

    Raises:
        NameError: if a component is neither passed in nor defined at
            notebook level.

    Notes:
        An empty ``val_loader`` yields an accuracy of 0.0 and an empty
        ``train_loader`` a loss of 0.0, instead of dividing by zero.
    """
    notebook_globals = globals()

    def _resolve(value, name):
        # Prefer the explicit argument; otherwise use the same-named global.
        if value is not None:
            return value
        if name not in notebook_globals:
            raise NameError(f"name '{name}' is not defined; pass {name}=... explicitly")
        return notebook_globals[name]

    criterion = _resolve(criterion, 'criterion')
    optimizer = _resolve(optimizer, 'optimizer')
    scheduler = _resolve(scheduler, 'scheduler')
    device = _resolve(device, 'device')

    train_losses = []
    val_accuracies = []

    for epoch in range(epochs):
        # Training
        model.train()
        epoch_loss = 0.0
        batches_seen = 0

        for piano_rolls, features, labels in train_loader:
            piano_rolls, features, labels = piano_rolls.to(device), features.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(piano_rolls, features)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            batches_seen += 1

        # Validation
        model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for piano_rolls, features, labels in val_loader:
                piano_rolls, features, labels = piano_rolls.to(device), features.to(device), labels.to(device)
                outputs = model(piano_rolls, features)
                predicted = outputs.argmax(dim=1)  # index of the largest logit
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_accuracy = 100 * correct / total if total else 0.0
        avg_loss = epoch_loss / batches_seen if batches_seen else 0.0

        train_losses.append(avg_loss)
        val_accuracies.append(val_accuracy)

        print(f'Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}, Val Acc: {val_accuracy:.2f}%')

        # Advance the learning-rate schedule once per epoch
        scheduler.step()

    return train_losses, val_accuracies

# Kick off the 25-epoch training run and keep the per-epoch histories
print("🚀 Starting multimodal training...")
train_losses, val_accuracies = train_multimodal_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=25,
)