In [8]:
import os
import kagglehub
import zipfile
import shutil
import numpy as np
import torch
import torch.nn as nn
import os
import numpy as np
import pretty_midi
import torch
from torch.utils.data import Dataset

In [2]:
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [20]:
if not hasattr(np, 'int'):
    np.int = int

In [4]:
# Define your model class (must match the architecture used in final-project1)
class CNN_LSTM_Classifier(nn.Module):
    def __init__(self, num_classes=4, lstm_hidden=256):
        super(CNN_LSTM_Classifier, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Dropout2d(0.2),
            nn.MaxPool2d(kernel_size=(2, 2))
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Dropout2d(0.3),
            nn.MaxPool2d(kernel_size=(2, 2))
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Dropout2d(0.4),
            nn.MaxPool2d(kernel_size=(2, 2))
        )
        self.feature_size = 128 * 16
        self.lstm = nn.LSTM(
            input_size=self.feature_size,
            hidden_size=lstm_hidden,
            num_layers=2,
            batch_first=True,
            dropout=0.3,
            bidirectional=True
        )
        self.attention = nn.MultiheadAttention(
            embed_dim=lstm_hidden * 2,
            num_heads=8,
            dropout=0.3,
            batch_first=True
        )
        self.classifier = nn.Sequential(
            nn.Linear(lstm_hidden * 2, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        batch_size = x.size(0)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.permute(0, 3, 1, 2)
        x = x.contiguous().view(batch_size, x.size(1), -1)
        lstm_out, _ = self.lstm(x)
        attn_out, _ = self.attention(lstm_out, lstm_out, lstm_out)
        pooled = torch.mean(attn_out, dim=1)
        output = self.classifier(pooled)
        return output


In [6]:

# Load the models
original_path = os.path.join('saved_models', 'original_cnn_lstm.pth')
rhythm_path = os.path.join('saved_models', 'rhythm_augmented_cnn_lstm.pth')

model = CNN_LSTM_Classifier(num_classes=4, lstm_hidden=256).to(device)
model.load_state_dict(torch.load(original_path, map_location=device))
print(f"✅ Loaded: {original_path}")

rhythm_model = CNN_LSTM_Classifier(num_classes=4, lstm_hidden=256).to(device)
rhythm_model.load_state_dict(torch.load(rhythm_path, map_location=device))
print(f"✅ Loaded: {rhythm_path}")

✅ Loaded: saved_models/original_cnn_lstm.pth
✅ Loaded: saved_models/rhythm_augmented_cnn_lstm.pth


In [14]:
TARGET_COMPOSERS = [
    'Bach',
    'Beethoven',
    'Chopin',
    'Mozart',
]

path = kagglehub.dataset_download("blanderbuss/midi-classic-music")

zip_path = os.path.join(path, 'midiclassics.zip')
extract_path = os.path.join('data', 'kaggle', 'midiclassics')


with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_path)

# list files in extract_path that contain the target composers in name
for composer in TARGET_COMPOSERS:
    composer_files = [f for f in os.listdir(extract_path) if composer.lower() in f.lower()]

# Only keep directories that contain a target composer's name
for item in os.listdir(extract_path):
    item_path = os.path.join(extract_path, item)
    if not any(composer.lower() in item.lower() for composer in TARGET_COMPOSERS):
        if os.path.isfile(item_path):
            os.remove(item_path)
        elif os.path.isdir(item_path):
            shutil.rmtree(item_path)

# also delete "C.P.E.Bach" files. This was the son of J.S. Bach, and we want to keep only the main composers
for item in os.listdir(extract_path):
    if 'C.P.E.Bach' in item:
        item_path = os.path.join(extract_path, item)
        if os.path.isfile(item_path):
            os.remove(item_path)
        elif os.path.isdir(item_path):
            shutil.rmtree(item_path)

In [15]:
class PianoRollDataset(Dataset):
    def __init__(self, data, labels):
        self.data = torch.tensor(data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        # Add channel dimension for CNN: (1, 128, T)
        return self.data[idx].unsqueeze(0), self.labels[idx]

In [16]:
def get_piano_roll_improved(midi_path, fs=100, target_duration=45.0):
    """
    Improved MIDI to piano roll conversion with musical awareness
    """
    try:
        pm = pretty_midi.PrettyMIDI(midi_path)
        
        # Get the actual duration of the piece
        actual_duration = pm.get_end_time()
        
        # If piece is very short, skip it
        if actual_duration < 10.0:  # Less than 10 seconds
            return None
            
        # For long pieces, extract multiple segments
        if actual_duration > target_duration * 1.5:
            # Extract from different parts of the piece
            segments = []
            num_segments = min(3, int(actual_duration // target_duration))
            
            for i in range(num_segments):
                start_time = i * (actual_duration / num_segments)
                end_time = start_time + target_duration
                
                # Create a copy and trim
                pm_segment = pretty_midi.PrettyMIDI()
                for instrument in pm.instruments:
                    new_instrument = pretty_midi.Instrument(
                        program=instrument.program,
                        is_drum=instrument.is_drum,
                        name=instrument.name
                    )
                    
                    for note in instrument.notes:
                        if start_time <= note.start < end_time:
                            new_note = pretty_midi.Note(
                                velocity=note.velocity,
                                pitch=note.pitch,
                                start=note.start - start_time,
                                end=min(note.end - start_time, target_duration)
                            )
                            new_instrument.notes.append(new_note)
                    
                    if new_instrument.notes:
                        pm_segment.instruments.append(new_instrument)
                
                if pm_segment.instruments:
                    piano_roll = pm_segment.get_piano_roll(fs=fs)
                    target_length = int(target_duration * fs)
                    
                    if piano_roll.shape[1] > target_length:
                        piano_roll = piano_roll[:, :target_length]
                    else:
                        pad_width = target_length - piano_roll.shape[1]
                        piano_roll = np.pad(piano_roll, ((0,0),(0,pad_width)), mode='constant')
                    
                    segments.append(piano_roll)
            
            return segments
        
        else:
            # For normal length pieces, use the whole piece
            piano_roll = pm.get_piano_roll(fs=fs)
            target_length = int(target_duration * fs)
            
            if piano_roll.shape[1] > target_length:
                # Take from the middle rather than truncating end
                start_idx = (piano_roll.shape[1] - target_length) // 2
                piano_roll = piano_roll[:, start_idx:start_idx + target_length]
            else:
                pad_width = target_length - piano_roll.shape[1]
                piano_roll = np.pad(piano_roll, ((0,0),(0,pad_width)), mode='constant')
            
            return [piano_roll]
            
    except Exception as e:
        print(f"Error processing {midi_path}: {e}")
        return None


In [17]:

def normalize_piano_roll(piano_roll):
    """
    Apply musical normalization to piano roll
    """
    # 1. Velocity normalization (already 0-1 from pretty_midi)
    normalized = piano_roll.copy()
    
    # 2. Optional: Focus on active pitch range
    active_pitches = np.any(normalized > 0, axis=1)
    if np.any(active_pitches):
        first_active = np.argmax(active_pitches)
        last_active = len(active_pitches) - 1 - np.argmax(active_pitches[::-1])
        
        # Ensure we keep a reasonable range (at least 60 semitones = 5 octaves)
        min_range = 60
        current_range = last_active - first_active + 1
        
        if current_range < min_range:
            expand = (min_range - current_range) // 2
            first_active = max(0, first_active - expand)
            last_active = min(127, last_active + expand)
    
    return normalized


In [18]:

def extract_musical_features(piano_roll):
    """
    Extract features that capture musical style
    """
    features = {}
    
    # Temporal features
    note_density_timeline = np.sum(piano_roll > 0, axis=0)
    features['avg_notes_per_time'] = np.mean(note_density_timeline)
    features['note_density_variance'] = np.var(note_density_timeline)
    
    # Pitch features
    pitch_activity = np.sum(piano_roll > 0, axis=1)
    active_pitches = pitch_activity > 0
    if np.any(active_pitches):
        features['pitch_range'] = np.sum(active_pitches)
        features['lowest_pitch'] = np.argmax(active_pitches)
        features['highest_pitch'] = 127 - np.argmax(active_pitches[::-1])
    else:
        features['pitch_range'] = 0
        features['lowest_pitch'] = 60  # Middle C
        features['highest_pitch'] = 60
    
    # Rhythmic features
    onset_pattern = np.diff(note_density_timeline > 0).astype(int)
    features['onset_density'] = np.sum(onset_pattern == 1) / len(onset_pattern)
    
    return features

print("✅ Improved data processing functions defined!")
print("Key improvements:")
print("• Intelligent segment extraction for long pieces")
print("• Musical boundary awareness")
print("• Better normalization")
print("• Feature extraction for analysis")

✅ Improved data processing functions defined!
Key improvements:
• Intelligent segment extraction for long pieces
• Musical boundary awareness
• Better normalization
• Feature extraction for analysis


In [21]:
# =====================================================
# IMPROVED DATA LOADING WITH BETTER PROCESSING
# =====================================================

import numpy as np

def load_improved_dataset(extract_path, target_composers, target_duration=45.0, max_files_per_composer=None):
    """
    Load dataset with improved processing that addresses previous shortcomings
    """
    print("🎵 LOADING DATASET WITH IMPROVED PROCESSING...")
    print("Improvements over original:")
    print("• Intelligent segment extraction for long pieces")
    print("• Better handling of piece lengths")
    print("• Musical feature extraction")
    print("• Quality filtering")
    
    composer_to_idx = {c: i for i, c in enumerate(target_composers)}
    all_data = []
    all_labels = []
    all_features = []
    
    for composer in target_composers:
        print(f"\n--- Processing {composer} ---")
        composer_dir = os.path.join(extract_path, composer)
        
        if not os.path.isdir(composer_dir):
            print(f"Directory not found: {composer_dir}")
            continue
            
        composer_data = []
        composer_labels = []
        composer_features = []
        files_processed = 0
        segments_created = 0
        
        midi_files = [f for f in os.listdir(composer_dir) 
                     if f.lower().endswith(('.mid', '.midi'))]
        
        if max_files_per_composer:
            midi_files = midi_files[:max_files_per_composer]
        
        for file in midi_files:
            midi_path = os.path.join(composer_dir, file)
            
            try:
                # Use improved processing
                segments = get_piano_roll_improved(midi_path, target_duration=target_duration)
                
                if segments is None:
                    continue
                    
                for segment in segments:
                    # Normalize the segment
                    normalized_segment = normalize_piano_roll(segment)
                    
                    # Extract musical features
                    features = extract_musical_features(normalized_segment)
                    
                    # Quality check: skip if too sparse
                    note_density = features['avg_notes_per_time']
                    if note_density < 0.1:  # Very sparse, likely poor quality
                        continue
                    
                    composer_data.append(normalized_segment)
                    composer_labels.append(composer_to_idx[composer])
                    composer_features.append(features)
                    segments_created += 1
                
                files_processed += 1
                
                if files_processed % 10 == 0:
                    print(f"  Processed {files_processed} files, created {segments_created} segments...")
                    
            except Exception as e:
                print(f"  Error processing {file}: {e}")
                continue
        
        print(f"✅ {composer}: {files_processed} files → {segments_created} segments")
        
        if composer_data:
            composer_data = np.array(composer_data)
            composer_labels = np.array(composer_labels)
            
            all_data.append(composer_data)
            all_labels.append(composer_labels) 
            all_features.extend(composer_features)
            
            print(f"  Final data shape: {composer_data.shape}")
    
    # Combine all data
    if all_data:
        data = np.concatenate(all_data, axis=0)
        labels = np.concatenate(all_labels, axis=0)
        
        print(f"\n🎯 FINAL IMPROVED DATASET:")
        print(f"Total samples: {len(data)}")
        print(f"Data shape: {data.shape}")
        print(f"Label distribution: {np.bincount(labels)}")
        
        return data, labels, all_features
    else:
        print("❌ No data loaded!")
        return None, None, None

# Load the improved dataset
print("🚀 Starting improved data loading...")
improved_data, improved_labels, features = load_improved_dataset(
    extract_path, 
    TARGET_COMPOSERS,
    target_duration=45.0,  # 45 seconds per segment
    max_files_per_composer=50  # Limit for testing - remove for full dataset
)

🚀 Starting improved data loading...
🎵 LOADING DATASET WITH IMPROVED PROCESSING...
Improvements over original:
• Intelligent segment extraction for long pieces
• Better handling of piece lengths
• Musical feature extraction
• Quality filtering

--- Processing Bach ---
  Processed 10 files, created 27 segments...
  Processed 20 files, created 54 segments...
  Processed 30 files, created 82 segments...
  Processed 40 files, created 110 segments...
  Processed 50 files, created 131 segments...
✅ Bach: 50 files → 131 segments
  Final data shape: (131, 128, 4500)

--- Processing Beethoven ---
  Processed 10 files, created 27 segments...
  Processed 20 files, created 55 segments...
  Processed 30 files, created 81 segments...
  Processed 40 files, created 104 segments...
  Processed 50 files, created 128 segments...
✅ Beethoven: 50 files → 128 segments
  Final data shape: (128, 128, 4500)

--- Processing Chopin ---
  Processed 10 files, created 23 segments...
  Processed 20 files, created 44 

In [22]:
# =====================================================
# TEST TRAINED MODELS ON IMPROVED DATASET
# =====================================================

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# Create dataset and dataloader for improved data
improved_dataset = PianoRollDataset(improved_data, improved_labels)
test_loader = DataLoader(improved_dataset, batch_size=32, shuffle=False)

def evaluate_model(model, dataloader, model_name):
    """Evaluate a single model"""
    model.eval()
    all_preds = []
    all_targets = []
    all_probs = []
    
    with torch.no_grad():
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            probs = F.softmax(outputs, dim=1)
            preds = torch.argmax(outputs, dim=1)
            
            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(target.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
    
    accuracy = 100 * np.mean(np.array(all_preds) == np.array(all_targets))
    
    print(f"\n📊 {model_name} Results:")
    print(f"Accuracy: {accuracy:.2f}%")
    print("\nClassification Report:")
    print(classification_report(all_targets, all_preds, 
                              target_names=TARGET_COMPOSERS, 
                              digits=3))
    
    return all_preds, all_targets, all_probs, accuracy

# Test both models
print("🧪 TESTING MODELS ON IMPROVED DATASET")
print("="*50)

# Test original model
orig_preds, targets, orig_probs, orig_acc = evaluate_model(model, test_loader, "Original Model")

# Test rhythm model  
rhythm_preds, _, rhythm_probs, rhythm_acc = evaluate_model(rhythm_model, test_loader, "Rhythm Model")

# Convert to numpy arrays
orig_probs = np.array(orig_probs)
rhythm_probs = np.array(rhythm_probs)
targets = np.array(targets)

🧪 TESTING MODELS ON IMPROVED DATASET

📊 Original Model Results:
Accuracy: 70.59%

Classification Report:
              precision    recall  f1-score   support

        Bach      0.708     0.908     0.796       131
   Beethoven      0.697     0.484     0.571       128
      Chopin      0.746     0.765     0.755       119
      Mozart      0.672     0.667     0.669       132

    accuracy                          0.706       510
   macro avg      0.706     0.706     0.698       510
weighted avg      0.705     0.706     0.697       510


📊 Rhythm Model Results:
Accuracy: 73.14%

Classification Report:
              precision    recall  f1-score   support

        Bach      0.797     0.809     0.803       131
   Beethoven      0.633     0.727     0.676       128
      Chopin      0.847     0.790     0.817       119
      Mozart      0.672     0.606     0.637       132

    accuracy                          0.731       510
   macro avg      0.737     0.733     0.734       510
weighted avg  

In [25]:
def get_piano_roll_non_overlapping_segments(midi_path, fs=100, segment_duration=45.0):
    """
    Extract non-overlapping segments to avoid data leakage with LSTM
    """
    try:
        pm = pretty_midi.PrettyMIDI(midi_path)
        
        # Get the actual duration of the piece
        actual_duration = pm.get_end_time()
        
        # If piece is very short, skip it
        if actual_duration < 15.0:  # Less than 15 seconds
            return None
            
        segments = []
        segment_size = segment_duration
        
        # Calculate number of non-overlapping segments
        num_segments = int(actual_duration // segment_size)
        
        # If piece is shorter than one segment, use the whole piece (padded)
        if num_segments == 0:
            piano_roll = pm.get_piano_roll(fs=fs)
            target_length = int(segment_duration * fs)
            
            if piano_roll.shape[1] < target_length:
                pad_width = target_length - piano_roll.shape[1]
                piano_roll = np.pad(piano_roll, ((0,0),(0,pad_width)), mode='constant')
            else:
                piano_roll = piano_roll[:, :target_length]
            
            segments.append(piano_roll)
            return segments
        
        # Extract non-overlapping segments
        for i in range(num_segments):
            start_time = i * segment_size
            end_time = start_time + segment_size
            
            # Create a copy and trim
            pm_segment = pretty_midi.PrettyMIDI()
            for instrument in pm.instruments:
                new_instrument = pretty_midi.Instrument(
                    program=instrument.program,
                    is_drum=instrument.is_drum,
                    name=instrument.name
                )
                
                for note in instrument.notes:
                    if start_time <= note.start < end_time:
                        new_note = pretty_midi.Note(
                            velocity=note.velocity,
                            pitch=note.pitch,
                            start=note.start - start_time,
                            end=min(note.end - start_time, segment_duration)
                        )
                        new_instrument.notes.append(new_note)
                
                if new_instrument.notes:
                    pm_segment.instruments.append(new_instrument)
            
            if pm_segment.instruments:
                piano_roll = pm_segment.get_piano_roll(fs=fs)
                target_length = int(segment_duration * fs)
                
                if piano_roll.shape[1] > target_length:
                    piano_roll = piano_roll[:, :target_length]
                else:
                    pad_width = target_length - piano_roll.shape[1]
                    piano_roll = np.pad(piano_roll, ((0,0),(0,pad_width)), mode='constant')
                
                segments.append(piano_roll)
        
        return segments if segments else None
            
    except Exception as e:
        print(f"Error processing {midi_path}: {e}")
        return None

print("✅ Non-overlapping segmentation function defined!")

✅ Non-overlapping segmentation function defined!


In [26]:
def load_dataset_non_overlapping_full(extract_path, target_composers, segment_duration=45.0):
    """
    Load FULL dataset with non-overlapping segments - using ALL available files
    """
    print("🎵 LOADING FULL DATASET WITH NON-OVERLAPPING SEGMENTS...")
    print("Benefits for LSTM training:")
    print(f"• Segment duration: {segment_duration}s (no overlap)")
    print("• Using ALL available files (no limits)")
    print("• No data leakage between train/test")
    print("• Cleaner temporal boundaries")
    print("• Will address class imbalance later during training")
    
    composer_to_idx = {c: i for i, c in enumerate(target_composers)}
    all_data = []
    all_labels = []
    all_features = []
    
    total_files_processed = 0
    total_segments_created = 0
    
    for composer in target_composers:
        print(f"\n--- Processing {composer} ---")
        composer_dir = os.path.join(extract_path, composer)
        
        if not os.path.isdir(composer_dir):
            print(f"Directory not found: {composer_dir}")
            continue
            
        composer_data = []
        composer_labels = []
        composer_features = []
        files_processed = 0
        segments_created = 0
        
        # Get ALL MIDI files - no limit
        midi_files = [f for f in os.listdir(composer_dir) 
                     if f.lower().endswith(('.mid', '.midi'))]
        
        print(f"  Found {len(midi_files)} MIDI files for {composer}")
        
        for file in midi_files:
            midi_path = os.path.join(composer_dir, file)
            
            try:
                # Use non-overlapping segmentation
                segments = get_piano_roll_non_overlapping_segments(
                    midi_path, 
                    segment_duration=segment_duration
                )
                
                if segments is None:
                    continue
                    
                for segment in segments:
                    # Normalize the segment
                    normalized_segment = normalize_piano_roll(segment)
                    
                    # Extract musical features
                    features = extract_musical_features(normalized_segment)
                    
                    # Quality check: skip if too sparse
                    note_density = features['avg_notes_per_time']
                    if note_density < 0.1:  # Very sparse, likely poor quality
                        continue
                    
                    composer_data.append(normalized_segment)
                    composer_labels.append(composer_to_idx[composer])
                    composer_features.append(features)
                    segments_created += 1
                
                files_processed += 1
                
                # Progress update every 50 files for full dataset
                if files_processed % 50 == 0:
                    print(f"  Processed {files_processed}/{len(midi_files)} files, created {segments_created} segments...")
                    
            except Exception as e:
                print(f"  Error processing {file}: {e}")
                continue
        
        print(f"✅ {composer}: {files_processed}/{len(midi_files)} files → {segments_created} segments")
        
        if composer_data:
            composer_data = np.array(composer_data)
            composer_labels = np.array(composer_labels)
            
            all_data.append(composer_data)
            all_labels.append(composer_labels) 
            all_features.extend(composer_features)
            
            print(f"  Final data shape: {composer_data.shape}")
        
        total_files_processed += files_processed
        total_segments_created += segments_created
    
    # Combine all data
    if all_data:
        data = np.concatenate(all_data, axis=0)
        labels = np.concatenate(all_labels, axis=0)
        
        print(f"\n🎯 FINAL FULL NON-OVERLAPPING DATASET:")
        print(f"Total files processed: {total_files_processed}")
        print(f"Total samples: {len(data)}")
        print(f"Data shape: {data.shape}")
        print(f"Label distribution: {np.bincount(labels)}")
        
        # Show class distribution percentages
        for i, composer in enumerate(target_composers):
            count = np.sum(labels == i)
            percentage = (count / len(labels)) * 100
            print(f"  {composer}: {count} samples ({percentage:.1f}%)")
        
        return data, labels, all_features
    else:
        print("❌ No data loaded!")
        return None, None, None

# Load the FULL dataset with non-overlapping segments
print("🚀 Starting FULL non-overlapping segment data loading...")
full_data, full_labels, full_features = load_dataset_non_overlapping_full(
    extract_path, 
    TARGET_COMPOSERS,
    segment_duration=45.0  # 45-second segments, no overlap, ALL files
)

🚀 Starting FULL non-overlapping segment data loading...
🎵 LOADING FULL DATASET WITH NON-OVERLAPPING SEGMENTS...
Benefits for LSTM training:
• Segment duration: 45.0s (no overlap)
• Using ALL available files (no limits)
• No data leakage between train/test
• Cleaner temporal boundaries
• Will address class imbalance later during training

--- Processing Bach ---
  Found 131 MIDI files for Bach
  Processed 50/131 files, created 355 segments...
  Processed 100/131 files, created 733 segments...
✅ Bach: 131/131 files → 977 segments
  Final data shape: (977, 128, 4500)

--- Processing Beethoven ---
  Found 134 MIDI files for Beethoven
  Processed 50/134 files, created 399 segments...
  Processed 100/134 files, created 765 segments...
Error processing data/kaggle/midiclassics/Beethoven/Anhang 14-3.mid: Could not decode key with 3 flats and mode 255
✅ Beethoven: 133/134 files → 987 segments
  Final data shape: (987, 128, 4500)

--- Processing Chopin ---
  Found 136 MIDI files for Chopin
  Pro

In [32]:
# =====================================================
# AGGRESSIVE CNN-LSTM-TRANSFORMER FOR A100 40GB
# =====================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class PositionalEncoding(nn.Module):
    """Enhanced positional encoding for transformer"""
    def __init__(self, d_model, max_len=10000):
        super(PositionalEncoding, self).__init__()
        
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        return x + self.pe[:x.size(0), :]

class AggressiveCNN_LSTM_Transformer(nn.Module):
    def __init__(self, num_classes=4, lstm_hidden=512, transformer_dim=1024, num_heads=16, num_layers=8):
        super(AggressiveCNN_LSTM_Transformer, self).__init__()
        
        print("🚀 Building AGGRESSIVE CNN-LSTM-Transformer for A100 40GB...")
        print(f"• Deep CNN feature extraction (6 blocks)")
        print(f"• Large LSTM temporal modeling (hidden: {lstm_hidden})")
        print(f"• Deep Transformer self-attention (dim: {transformer_dim}, heads: {num_heads}, layers: {num_layers})")
        print(f"• Multi-scale feature fusion")
        print(f"• Advanced attention mechanisms")
        
        # ==========================================
        # DEEP CNN BACKBONE - 6 BLOCKS
        # ==========================================
        
        # Block 1: Initial feature extraction
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Dropout2d(0.1),
            nn.MaxPool2d(kernel_size=(2, 2))  # 128x64 -> 64x32
        )
        
        # Block 2: Deeper features
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Dropout2d(0.15),
            nn.MaxPool2d(kernel_size=(2, 2))  # 64x32 -> 32x16
        )
        
        # Block 3: More complex patterns
        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Dropout2d(0.2),
            nn.MaxPool2d(kernel_size=(2, 2))  # 32x16 -> 16x8
        )
        
        # Block 4: High-level features
        self.conv4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Dropout2d(0.25),
            nn.MaxPool2d(kernel_size=(2, 2))  # 16x8 -> 8x4
        )
        
        # Block 5: Abstract features
        self.conv5 = nn.Sequential(
            nn.Conv2d(512, 768, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(768),
            nn.ReLU(),
            nn.Conv2d(768, 768, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(768),
            nn.ReLU(),
            nn.Dropout2d(0.3),
            nn.MaxPool2d(kernel_size=(2, 2))  # 8x4 -> 4x2
        )
        
        # Block 6: Final feature extraction
        self.conv6 = nn.Sequential(
            nn.Conv2d(768, 1024, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.Conv2d(1024, 1024, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.Dropout2d(0.35),
            nn.AdaptiveAvgPool2d((2, 1))  # Ensure consistent output: 2x1
        )
        
        # ==========================================
        # LARGE BIDIRECTIONAL LSTM
        # ==========================================
        self.feature_size = 1024 * 2  # 1024 channels * 2x1 spatial
        self.lstm_hidden = lstm_hidden
        
        # Multi-layer LSTM with larger capacity
        self.lstm = nn.LSTM(
            input_size=self.feature_size,
            hidden_size=lstm_hidden,
            num_layers=4,  # Deeper LSTM
            batch_first=True,
            dropout=0.3,
            bidirectional=True
        )
        
        # Additional LSTM for temporal refinement
        self.lstm_refine = nn.LSTM(
            input_size=lstm_hidden * 2,
            hidden_size=lstm_hidden // 2,
            num_layers=2,
            batch_first=True,
            dropout=0.2,
            bidirectional=True
        )
        
        # ==========================================
        # DEEP TRANSFORMER ENCODER
        # ==========================================
        self.transformer_dim = transformer_dim
        
        # Project LSTM output to transformer dimension
        self.lstm_to_transformer = nn.Sequential(
            nn.Linear(lstm_hidden, transformer_dim),
            nn.LayerNorm(transformer_dim),
            nn.ReLU(),
            nn.Dropout(0.1)
        )
        
        # Positional encoding
        self.pos_encoding = PositionalEncoding(transformer_dim, max_len=10000)
        
        # Deep transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=transformer_dim,
            nhead=num_heads,
            dim_feedforward=transformer_dim * 4,  # Large feedforward
            dropout=0.1,
            activation='gelu',  # GELU activation for better performance
            batch_first=True,
            norm_first=True  # Pre-norm for better training stability
        )
        
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, 
            num_layers=num_layers,
            norm=nn.LayerNorm(transformer_dim)
        )
        
        # ==========================================
        # MULTI-SCALE ATTENTION & FUSION
        # ==========================================
        
        # Multi-scale attention heads
        self.global_attention = nn.MultiheadAttention(
            embed_dim=transformer_dim,
            num_heads=num_heads,
            dropout=0.1,
            batch_first=True
        )
        
        self.local_attention = nn.MultiheadAttention(
            embed_dim=transformer_dim,
            num_heads=num_heads // 2,
            dropout=0.1,
            batch_first=True
        )
        
        # Cross-attention between global and local features
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=transformer_dim,
            num_heads=num_heads // 2,
            dropout=0.1,
            batch_first=True
        )
        
        # Feature fusion
        self.feature_fusion = nn.Sequential(
            nn.Linear(transformer_dim * 3, transformer_dim),
            nn.LayerNorm(transformer_dim),
            nn.GELU(),
            nn.Dropout(0.2)
        )
        
        # ==========================================
        # ADVANCED CLASSIFICATION HEAD
        # ==========================================
        
        # Hierarchical classification with multiple paths
        self.classifier = nn.Sequential(
            nn.LayerNorm(transformer_dim),
            nn.Linear(transformer_dim, 2048),
            nn.GELU(),
            nn.Dropout(0.5),
            
            nn.Linear(2048, 1024),
            nn.GELU(),
            nn.Dropout(0.4),
            
            nn.Linear(1024, 512),
            nn.GELU(),
            nn.Dropout(0.3),
            
            nn.Linear(512, 256),
            nn.GELU(),
            nn.Dropout(0.2),
            
            nn.Linear(256, num_classes)
        )
        
        # Auxiliary classifier for regularization
        self.aux_classifier = nn.Sequential(
            nn.Linear(lstm_hidden, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )
        
        print("✅ AGGRESSIVE CNN-LSTM-Transformer architecture built!")
        
    def forward(self, x):
        batch_size = x.size(0)
        
        # ==========================================
        # DEEP CNN FEATURE EXTRACTION
        # ==========================================
        x = self.conv1(x)      # (batch, 64, 64, T/2)
        x = self.conv2(x)      # (batch, 128, 32, T/4)
        x = self.conv3(x)      # (batch, 256, 16, T/8)
        x = self.conv4(x)      # (batch, 512, 8, T/16)
        x = self.conv5(x)      # (batch, 768, 4, T/32)
        x = self.conv6(x)      # (batch, 1024, 2, T/64)
        
        # Reshape for LSTM: (batch, time_steps, features)
        x = x.permute(0, 3, 1, 2)  # (batch, T/64, 1024, 2)
        x = x.contiguous().view(batch_size, x.size(1), -1)  # (batch, T/64, 1024*2)
        
        # ==========================================
        # LARGE LSTM PROCESSING
        # ==========================================
        lstm_out, _ = self.lstm(x)  # (batch, T/64, lstm_hidden*2)
        lstm_refined, _ = self.lstm_refine(lstm_out)  # (batch, T/64, lstm_hidden)
        
        # Auxiliary classification from LSTM features (for regularization)
        lstm_pooled = torch.mean(lstm_refined, dim=1)
        aux_output = self.aux_classifier(lstm_pooled)
        
        # ==========================================
        # DEEP TRANSFORMER PROCESSING
        # ==========================================
        
        # Project to transformer dimension
        transformer_input = self.lstm_to_transformer(lstm_refined)  # (batch, T/64, transformer_dim)
        
        # Add positional encoding
        transformer_input = transformer_input.transpose(0, 1)  # (T/64, batch, transformer_dim)
        transformer_input = self.pos_encoding(transformer_input)
        transformer_input = transformer_input.transpose(0, 1)  # (batch, T/64, transformer_dim)
        
        # Deep transformer encoding
        transformer_out = self.transformer_encoder(transformer_input)  # (batch, T/64, transformer_dim)
        
        # ==========================================
        # MULTI-SCALE ATTENTION & FUSION
        # ==========================================
        
        # Global attention (full sequence)
        global_attended, _ = self.global_attention(
            transformer_out, transformer_out, transformer_out
        )
        
        # Local attention (sliding window - simulate by chunking)
        seq_len = transformer_out.size(1)
        if seq_len > 16:
            # Use overlapping windows
            local_features = []
            window_size = min(16, seq_len)
            for i in range(0, max(1, seq_len - window_size + 1), window_size // 2):
                end_idx = min(i + window_size, seq_len)
                window = transformer_out[:, i:end_idx, :]
                local_att, _ = self.local_attention(window, window, window)
                local_features.append(torch.mean(local_att, dim=1, keepdim=True))
            local_attended = torch.cat(local_features, dim=1)
        else:
            local_attended, _ = self.local_attention(
                transformer_out, transformer_out, transformer_out
            )
        
        # Cross attention between global and local
        cross_attended, _ = self.cross_attention(
            global_attended, local_attended, local_attended
        )
        
        # Fusion of multi-scale features
        # Pool to same size for concatenation
        global_pooled = torch.mean(global_attended, dim=1)
        local_pooled = torch.mean(local_attended, dim=1)
        cross_pooled = torch.mean(cross_attended, dim=1)
        
        fused_features = torch.cat([global_pooled, local_pooled, cross_pooled], dim=1)
        final_features = self.feature_fusion(fused_features)
        
        # ==========================================
        # CLASSIFICATION
        # ==========================================
        main_output = self.classifier(final_features)
        
        return main_output, aux_output

# Create the aggressive model
print("🚀 Creating AGGRESSIVE model for A100 40GB...")
aggressive_model = AggressiveCNN_LSTM_Transformer(
    num_classes=4,
    lstm_hidden=512,        # Doubled from 256
    transformer_dim=1024,   # Doubled from 512  
    num_heads=16,          # Doubled from 8
    num_layers=8           # Doubled from 4
).to(device)

# Test with dummy input
test_input = torch.randn(4, 1, 128, 4500).to(device)  # Larger batch
print(f"Input shape: {test_input.shape}")

with torch.no_grad():
    main_out, aux_out = aggressive_model(test_input)
    print(f"Main output shape: {main_out.shape}")
    print(f"Auxiliary output shape: {aux_out.shape}")
    print(f"✅ Aggressive model forward pass successful!")

# Count parameters
total_params = sum(p.numel() for p in aggressive_model.parameters())
trainable_params = sum(p.numel() for p in aggressive_model.parameters() if p.requires_grad)

print(f"\n📊 AGGRESSIVE Model Statistics:")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Model size: ~{total_params * 4 / 1024 / 1024:.1f} MB")
print(f"Estimated GPU memory (training): ~{total_params * 16 / 1024 / 1024:.1f} MB")

🚀 Creating AGGRESSIVE model for A100 40GB...
🚀 Building AGGRESSIVE CNN-LSTM-Transformer for A100 40GB...
• Deep CNN feature extraction (6 blocks)
• Large LSTM temporal modeling (hidden: 512)
• Deep Transformer self-attention (dim: 1024, heads: 16, layers: 8)
• Multi-scale feature fusion
• Advanced attention mechanisms




✅ AGGRESSIVE CNN-LSTM-Transformer architecture built!
Input shape: torch.Size([4, 1, 128, 4500])
Main output shape: torch.Size([4, 4])
Auxiliary output shape: torch.Size([4, 4])
✅ Aggressive model forward pass successful!

📊 AGGRESSIVE Model Statistics:
Total parameters: 185,688,776
Trainable parameters: 185,688,776
Model size: ~708.3 MB
Estimated GPU memory (training): ~2833.4 MB


In [33]:
# =====================================================
# SEQUENCE-AWARE DATA LOADING WITH PIECE TRACKING
# =====================================================

def load_dataset_with_piece_tracking(extract_path, target_composers, segment_duration=45.0):
    """
    Load dataset while tracking which segments belong to the same piece
    This enables sequence-aware training that distinguishes:
    - Consecutive segments from SAME piece vs DIFFERENT pieces
    - Same composer vs different composer relationships
    """
    print("🎵 LOADING DATASET WITH PIECE TRACKING...")
    print("Key improvements:")
    print(f"• Track piece identity for each segment")
    print(f"• Enable consecutive segment modeling")
    print(f"• Support contrastive learning approaches")
    print(f"• Segment duration: {segment_duration}s")
    
    composer_to_idx = {c: i for i, c in enumerate(target_composers)}
    all_data = []
    all_labels = []
    all_piece_ids = []  # NEW: Track which piece each segment comes from
    all_piece_names = []  # NEW: Track piece names for analysis
    
    piece_id_counter = 0
    total_files_processed = 0
    total_segments_created = 0
    
    for composer in target_composers:
        print(f"\n--- Processing {composer} ---")
        composer_dir = os.path.join(extract_path, composer)
        
        if not os.path.isdir(composer_dir):
            print(f"Directory not found: {composer_dir}")
            continue
            
        midi_files = [f for f in os.listdir(composer_dir) 
                     if f.lower().endswith(('.mid', '.midi'))]
        
        print(f"  Found {len(midi_files)} MIDI files for {composer}")
        files_processed = 0
        segments_created = 0
        
        for file in midi_files:
            midi_path = os.path.join(composer_dir, file)
            
            try:
                # Use non-overlapping segmentation
                segments = get_piano_roll_non_overlapping_segments(
                    midi_path, segment_duration=segment_duration
                )
                
                if segments is None or len(segments) == 0:
                    continue
                
                # All segments from this file get the same piece_id
                current_piece_id = piece_id_counter
                piece_name = f"{composer}_{file}"
                piece_id_counter += 1
                
                valid_segments_from_piece = 0
                
                for segment_idx, segment in enumerate(segments):
                    # Normalize the segment
                    normalized_segment = normalize_piano_roll(segment)
                    
                    # Extract musical features
                    features = extract_musical_features(normalized_segment)
                    
                    # Quality check: skip if too sparse
                    if features['avg_notes_per_time'] < 0.1:
                        continue
                    
                    all_data.append(normalized_segment)
                    all_labels.append(composer_to_idx[composer])
                    all_piece_ids.append(current_piece_id)
                    all_piece_names.append(f"{piece_name}_seg{segment_idx}")
                    
                    valid_segments_from_piece += 1
                    segments_created += 1
                
                if valid_segments_from_piece > 0:
                    files_processed += 1
                
                # Progress update
                if files_processed % 50 == 0:
                    print(f"  Processed {files_processed}/{len(midi_files)} files, created {segments_created} segments...")
                    
            except Exception as e:
                print(f"  Error processing {file}: {e}")
                continue
        
        print(f"✅ {composer}: {files_processed}/{len(midi_files)} files → {segments_created} segments")
        total_files_processed += files_processed
        total_segments_created += segments_created
    
    # Convert to numpy arrays
    data = np.array(all_data)
    labels = np.array(all_labels)
    piece_ids = np.array(all_piece_ids)
    
    print(f"\n🎯 FINAL PIECE-TRACKED DATASET:")
    print(f"Total files processed: {total_files_processed}")
    print(f"Total segments: {len(data)}")
    print(f"Total unique pieces: {len(np.unique(piece_ids))}")
    print(f"Data shape: {data.shape}")
    print(f"Label distribution: {np.bincount(labels)}")
    
    # Analyze segments per piece
    segments_per_piece = []
    for piece_id in np.unique(piece_ids):
        count = np.sum(piece_ids == piece_id)
        segments_per_piece.append(count)
    
    print(f"Segments per piece - Mean: {np.mean(segments_per_piece):.1f}, "
          f"Min: {np.min(segments_per_piece)}, Max: {np.max(segments_per_piece)}")
    
    # Show class distribution percentages
    for i, composer in enumerate(target_composers):
        count = np.sum(labels == i)
        percentage = (count / len(labels)) * 100
        print(f"  {composer}: {count} segments ({percentage:.1f}%)")
    
    return data, labels, piece_ids, all_piece_names

# Load the dataset with piece tracking
print("🚀 Starting piece-tracked data loading...")
tracked_data, tracked_labels, tracked_piece_ids, piece_names = load_dataset_with_piece_tracking(
    extract_path, 
    TARGET_COMPOSERS,
    segment_duration=45.0
)

🚀 Starting piece-tracked data loading...
🎵 LOADING DATASET WITH PIECE TRACKING...
Key improvements:
• Track piece identity for each segment
• Enable consecutive segment modeling
• Support contrastive learning approaches
• Segment duration: 45.0s

--- Processing Bach ---
  Found 131 MIDI files for Bach




  Processed 50/131 files, created 355 segments...
  Processed 100/131 files, created 733 segments...
✅ Bach: 131/131 files → 977 segments

--- Processing Beethoven ---
  Found 134 MIDI files for Beethoven
  Processed 50/134 files, created 399 segments...
  Processed 100/134 files, created 765 segments...
Error processing data/kaggle/midiclassics/Beethoven/Anhang 14-3.mid: Could not decode key with 3 flats and mode 255
✅ Beethoven: 133/134 files → 987 segments

--- Processing Chopin ---
  Found 136 MIDI files for Chopin
  Processed 50/136 files, created 235 segments...
  Processed 100/136 files, created 467 segments...
✅ Chopin: 136/136 files → 610 segments

--- Processing Mozart ---
  Found 90 MIDI files for Mozart
  Processed 50/90 files, created 263 segments...
✅ Mozart: 90/90 files → 524 segments

🎯 FINAL PIECE-TRACKED DATASET:
Total files processed: 490
Total segments: 3098
Total unique pieces: 490
Data shape: (3098, 128, 4500)
Label distribution: [977 987 610 524]
Segments per pie

In [34]:
# =====================================================
# CLASS WEIGHTS IMPLEMENTATION
# =====================================================

from sklearn.utils.class_weight import compute_class_weight
import torch.nn as nn

# Compute class weights for our imbalanced dataset
def compute_class_weights(labels, target_composers):
    """
    Compute class weights to handle imbalance
    """
    print("⚖️ COMPUTING CLASS WEIGHTS...")
    
    # Get unique labels and compute balanced weights
    unique_labels = np.unique(labels)
    class_weights = compute_class_weight(
        'balanced', 
        classes=unique_labels, 
        y=labels
    )
    
    print(f"Original distribution:")
    for i, composer in enumerate(target_composers):
        count = np.sum(labels == i)
        percentage = (count / len(labels)) * 100
        print(f"  {composer}: {count} samples ({percentage:.1f}%)")
    
    print(f"\nComputed class weights:")
    for i, (composer, weight) in enumerate(zip(target_composers, class_weights)):
        print(f"  {composer}: {weight:.3f}")
    
    # Convert to PyTorch tensor
    class_weights_tensor = torch.FloatTensor(class_weights).to(device)
    
    print(f"\n✅ Class weights ready for CrossEntropyLoss")
    return class_weights_tensor

# Compute class weights for our tracked dataset
class_weights = compute_class_weights(tracked_labels, TARGET_COMPOSERS)

⚖️ COMPUTING CLASS WEIGHTS...
Original distribution:
  Bach: 977 samples (31.5%)
  Beethoven: 987 samples (31.9%)
  Chopin: 610 samples (19.7%)
  Mozart: 524 samples (16.9%)

Computed class weights:
  Bach: 0.793
  Beethoven: 0.785
  Chopin: 1.270
  Mozart: 1.478

✅ Class weights ready for CrossEntropyLoss


In [35]:
# =====================================================
# SEQUENCE-AWARE DATASET FOR CONSECUTIVE SEGMENTS
# =====================================================

from torch.utils.data import Dataset
import random

class SequenceAwareDataset(Dataset):
    """
    Dataset that creates sequences of consecutive segments from same pieces
    This enables the model to learn temporal relationships within compositions
    """
    def __init__(self, data, labels, piece_ids, sequence_length=2, include_singles=True):
        self.data = torch.tensor(data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)
        self.piece_ids = torch.tensor(piece_ids, dtype=torch.long)
        self.sequence_length = sequence_length
        self.include_singles = include_singles
        
        # Group segments by piece
        self.piece_segments = {}
        for idx, piece_id in enumerate(piece_ids):
            piece_id = int(piece_id)
            if piece_id not in self.piece_segments:
                self.piece_segments[piece_id] = []
            self.piece_segments[piece_id].append(idx)
        
        # Create sequence indices
        self.sequence_indices = self._create_sequence_indices()
        
        print(f"📊 SEQUENCE DATASET CREATED:")
        print(f"• Sequence length: {sequence_length}")
        print(f"• Total pieces: {len(self.piece_segments)}")
        print(f"• Total sequences: {len(self.sequence_indices)}")
        print(f"• Include single segments: {include_singles}")
    
    def _create_sequence_indices(self):
        """Create indices for all possible sequences"""
        sequences = []
        
        # Add consecutive sequences from same pieces
        for piece_id, segment_indices in self.piece_segments.items():
            if len(segment_indices) >= self.sequence_length:
                # Create all possible consecutive sequences from this piece
                for start_idx in range(len(segment_indices) - self.sequence_length + 1):
                    seq_indices = segment_indices[start_idx:start_idx + self.sequence_length]
                    sequences.append({
                        'type': 'sequence',
                        'indices': seq_indices,
                        'piece_id': piece_id,
                        'label': int(self.labels[seq_indices[0]])  # All should have same label
                    })
        
        # Optionally add single segments
        if self.include_singles:
            for piece_id, segment_indices in self.piece_segments.items():
                for idx in segment_indices:
                    sequences.append({
                        'type': 'single',
                        'indices': [idx],
                        'piece_id': piece_id,
                        'label': int(self.labels[idx])
                    })
        
        return sequences
    
    def __len__(self):
        return len(self.sequence_indices)
    
    def __getitem__(self, idx):
        sequence_info = self.sequence_indices[idx]
        indices = sequence_info['indices']
        
        if len(indices) == 1:
            # Single segment
            segment = self.data[indices[0]].unsqueeze(0)  # Add channel dim
            return segment, sequence_info['label'], sequence_info['piece_id'], 'single'
        else:
            # Multiple segments - stack them
            segments = []
            for seg_idx in indices:
                segments.append(self.data[seg_idx].unsqueeze(0))  # Add channel dim
            
            # Stack segments: (sequence_length, 1, 128, T)
            sequence = torch.stack(segments, dim=0)
            
            return sequence, sequence_info['label'], sequence_info['piece_id'], 'sequence'

# Create sequence-aware dataset
print("🔗 Creating sequence-aware dataset...")
sequence_dataset = SequenceAwareDataset(
    tracked_data, 
    tracked_labels, 
    tracked_piece_ids,
    sequence_length=2,  # Start with pairs
    include_singles=True
)

🔗 Creating sequence-aware dataset...
📊 SEQUENCE DATASET CREATED:
• Sequence length: 2
• Total pieces: 490
• Total sequences: 5706
• Include single segments: True


In [36]:
# =====================================================
# TRAINING SETUP WITH CLASS WEIGHTS
# =====================================================

import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from sklearn.model_selection import train_test_split

def create_balanced_training_setup(data, labels, piece_ids, class_weights, test_size=0.2, val_size=0.15):
    """
    Create training setup with class weights for imbalance handling
    """
    print("🎯 CREATING BALANCED TRAINING SETUP...")
    print(f"Using class weights: {class_weights}")
    
    # Split data at the piece level to avoid data leakage
    unique_pieces = np.unique(piece_ids)
    piece_labels = np.array([labels[piece_ids == pid][0] for pid in unique_pieces])
    
    # Split pieces into train/val/test
    train_pieces, test_pieces, _, _ = train_test_split(
        unique_pieces, piece_labels, 
        test_size=test_size, 
        stratify=piece_labels, 
        random_state=42
    )
    
    train_pieces, val_pieces, _, _ = train_test_split(
        train_pieces, piece_labels[np.isin(unique_pieces, train_pieces)], 
        test_size=val_size/(1-test_size), 
        stratify=piece_labels[np.isin(unique_pieces, train_pieces)], 
        random_state=42
    )
    
    # Create masks for train/val/test
    train_mask = np.isin(piece_ids, train_pieces)
    val_mask = np.isin(piece_ids, val_pieces)
    test_mask = np.isin(piece_ids, test_pieces)
    
    print(f"📊 DATA SPLIT:")
    print(f"Train pieces: {len(train_pieces)} | segments: {np.sum(train_mask)}")
    print(f"Val pieces:   {len(val_pieces)} | segments: {np.sum(val_mask)}")
    print(f"Test pieces:  {len(test_pieces)} | segments: {np.sum(test_mask)}")
    
    # Show class distribution per split
    for split_name, mask in [("Train", train_mask), ("Val", val_mask), ("Test", test_mask)]:
        split_labels = labels[mask]
        print(f"\n{split_name} distribution:")
        for i, composer in enumerate(TARGET_COMPOSERS):
            count = np.sum(split_labels == i)
            percentage = (count / len(split_labels)) * 100 if len(split_labels) > 0 else 0
            print(f"  {composer}: {count} ({percentage:.1f}%)")
    
    # Create datasets
    train_dataset = PianoRollDataset(data[train_mask], labels[train_mask])
    val_dataset = PianoRollDataset(data[val_mask], labels[val_mask])
    test_dataset = PianoRollDataset(data[test_mask], labels[test_mask])
    
    return train_dataset, val_dataset, test_dataset, train_mask, val_mask, test_mask

# Create the balanced training setup
train_dataset, val_dataset, test_dataset, train_mask, val_mask, test_mask = create_balanced_training_setup(
    tracked_data, tracked_labels, tracked_piece_ids, class_weights
)

# Create weighted loss function
weighted_criterion = nn.CrossEntropyLoss(weight=class_weights)
print(f"\n✅ Weighted loss function created with class weights")

# Create data loaders
batch_size = 16  # Smaller batch size for stability
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

print(f"\n📚 DATA LOADERS CREATED:")
print(f"• Batch size: {batch_size}")
print(f"• Train batches: {len(train_loader)}")
print(f"• Val batches: {len(val_loader)}")
print(f"• Test batches: {len(test_loader)}")

🎯 CREATING BALANCED TRAINING SETUP...
Using class weights: tensor([0.7927, 0.7847, 1.2697, 1.4781], device='mps:0')
📊 DATA SPLIT:
Train pieces: 318 | segments: 2015
Val pieces:   74 | segments: 471
Test pieces:  98 | segments: 612

Train distribution:
  Bach: 602 (29.9%)
  Beethoven: 643 (31.9%)
  Chopin: 433 (21.5%)
  Mozart: 337 (16.7%)

Val distribution:
  Bach: 171 (36.3%)
  Beethoven: 175 (37.2%)
  Chopin: 68 (14.4%)
  Mozart: 57 (12.1%)

Test distribution:
  Bach: 204 (33.3%)
  Beethoven: 169 (27.6%)
  Chopin: 109 (17.8%)
  Mozart: 130 (21.2%)

✅ Weighted loss function created with class weights

📚 DATA LOADERS CREATED:
• Batch size: 16
• Train batches: 126
• Val batches: 30
• Test batches: 39


In [None]:
# =====================================================
# ADVANCED TRAINING LOOP FOR AGGRESSIVE MODEL
# =====================================================

import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, ReduceLROnPlateau
import time
from collections import defaultdict
import matplotlib.pyplot as plt

def train_aggressive_model(model, train_loader, val_loader, class_weights, 
                          epochs=50, initial_lr=1e-3, save_path='aggressive_model.pth'):
    """
    Advanced training with:
    - Auxiliary loss for regularization
    - Cosine annealing with warm restarts
    - Mixed precision training
    - Advanced metrics tracking
    - Early stopping with patience
    """
    print("🚀 STARTING AGGRESSIVE MODEL TRAINING...")
    print(f"• Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"• Training samples: {len(train_loader.dataset)}")
    print(f"• Validation samples: {len(val_loader.dataset)}")
    print(f"• Epochs: {epochs}")
    print(f"• Initial learning rate: {initial_lr}")
    print(f"• Using mixed precision: True")
    print(f"• Class weights: {class_weights}")
    
    # Setup optimizers and schedulers
    optimizer = optim.AdamW(
        model.parameters(), 
        lr=initial_lr, 
        weight_decay=1e-4,
        betas=(0.9, 0.999),
        eps=1e-8
    )
    
    # Cosine annealing with warm restarts
    scheduler = CosineAnnealingWarmRestarts(
        optimizer, 
        T_0=10,  # Restart every 10 epochs
        T_mult=2,  # Double the period after each restart
        eta_min=1e-6
    )
    
    # Mixed precision scaler
    scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None
    
    # Loss functions
    main_criterion = nn.CrossEntropyLoss(weight=class_weights)
    aux_criterion = nn.CrossEntropyLoss(weight=class_weights)
    
    # Training tracking
    history = {
        'train_loss': [], 'train_acc': [], 'train_main_loss': [], 'train_aux_loss': [],
        'val_loss': [], 'val_acc': [], 'val_main_loss': [], 'val_aux_loss': [],
        'lr': [], 'epoch_time': []
    }
    
    best_val_acc = 0.0
    patience_counter = 0
    patience = 15  # Early stopping patience
    
    print("\n🎯 Training Configuration:")
    print(f"• Optimizer: AdamW (lr={initial_lr}, weight_decay=1e-4)")
    print(f"• Scheduler: CosineAnnealingWarmRestarts (T_0=10, T_mult=2)")
    print(f"• Main loss weight: 0.7, Auxiliary loss weight: 0.3")
    print(f"• Early stopping patience: {patience}")
    print(f"• Mixed precision: {'CUDA' if scaler else 'Disabled'}")
    
    model.to(device)
    
    for epoch in range(epochs):
        epoch_start_time = time.time()
        
        # ==========================================
        # TRAINING PHASE
        # ==========================================
        model.train()
        train_metrics = defaultdict(float)
        train_correct = 0
        train_total = 0
        
        print(f"\n📈 Epoch {epoch+1}/{epochs}")
        print("-" * 50)
        
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            
            optimizer.zero_grad()
            
            # Mixed precision forward pass
            if scaler:
                with torch.cuda.amp.autocast():
                    main_output, aux_output = model(data)
                    main_loss = main_criterion(main_output, target)
                    aux_loss = aux_criterion(aux_output, target)
                    # Combine losses with weights
                    total_loss = 0.7 * main_loss + 0.3 * aux_loss
                
                # Mixed precision backward pass
                scaler.scale(total_loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                main_output, aux_output = model(data)
                main_loss = main_criterion(main_output, target)
                aux_loss = aux_criterion(aux_output, target)
                total_loss = 0.7 * main_loss + 0.3 * aux_loss
                
                total_loss.backward()
                optimizer.step()
            
            # Track metrics
            train_metrics['total_loss'] += total_loss.item()
            train_metrics['main_loss'] += main_loss.item()
            train_metrics['aux_loss'] += aux_loss.item()
            
            # Accuracy from main output
            _, predicted = torch.max(main_output.data, 1)
            train_total += target.size(0)
            train_correct += (predicted == target).sum().item()
            
            # Progress update
            if batch_idx % 20 == 0:
                current_lr = optimizer.param_groups[0]['lr']
                print(f"  Batch {batch_idx:3d}/{len(train_loader)} | "
                      f"Loss: {total_loss.item():.4f} | "
                      f"Acc: {100.*train_correct/train_total:.2f}% | "
                      f"LR: {current_lr:.2e}")
        
        # ==========================================
        # VALIDATION PHASE
        # ==========================================
        model.eval()
        val_metrics = defaultdict(float)
        val_correct = 0
        val_total = 0
        
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                
                if scaler:
                    with torch.cuda.amp.autocast():
                        main_output, aux_output = model(data)
                        main_loss = main_criterion(main_output, target)
                        aux_loss = aux_criterion(aux_output, target)
                        total_loss = 0.7 * main_loss + 0.3 * aux_loss
                else:
                    main_output, aux_output = model(data)
                    main_loss = main_criterion(main_output, target)
                    aux_loss = aux_criterion(aux_output, target)
                    total_loss = 0.7 * main_loss + 0.3 * aux_loss
                
                val_metrics['total_loss'] += total_loss.item()
                val_metrics['main_loss'] += main_loss.item()
                val_metrics['aux_loss'] += aux_loss.item()
                
                _, predicted = torch.max(main_output.data, 1)
                val_total += target.size(0)
                val_correct += (predicted == target).sum().item()
        
        # Calculate epoch metrics
        train_loss = train_metrics['total_loss'] / len(train_loader)
        train_acc = 100. * train_correct / train_total
        val_loss = val_metrics['total_loss'] / len(val_loader)
        val_acc = 100. * val_correct / val_total
        
        # Update learning rate
        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']
        
        # Record history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['train_main_loss'].append(train_metrics['main_loss'] / len(train_loader))
        history['train_aux_loss'].append(train_metrics['aux_loss'] / len(train_loader))
        
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['val_main_loss'].append(val_metrics['main_loss'] / len(val_loader))
        history['val_aux_loss'].append(val_metrics['aux_loss'] / len(val_loader))
        
        history['lr'].append(current_lr)
        
        epoch_time = time.time() - epoch_start_time
        history['epoch_time'].append(epoch_time)
        
        # Print epoch summary
        print(f"\n📊 Epoch {epoch+1} Summary:")
        print(f"  Train: Loss={train_loss:.4f}, Acc={train_acc:.2f}%")
        print(f"  Val:   Loss={val_loss:.4f}, Acc={val_acc:.2f}%")
        print(f"  LR: {current_lr:.2e}, Time: {epoch_time:.1f}s")
        
        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_val_acc': best_val_acc,
                'history': history
            }, save_path)
            print(f"  💾 New best model saved! Val Acc: {val_acc:.2f}%")
        else:
            patience_counter += 1
            print(f"  ⏳ Patience: {patience_counter}/{patience}")
        
        # Early stopping
        if patience_counter >= patience:
            print(f"\n🛑 Early stopping triggered! Best Val Acc: {best_val_acc:.2f}%")
            break
        
        # Memory cleanup
        if device.type == 'cuda':
            torch.cuda.empty_cache()
    
    print(f"\n✅ Training completed!")
    print(f"Best validation accuracy: {best_val_acc:.2f}%")
    print(f"Model saved to: {save_path}")
    
    return history

# Create save directory
os.makedirs('saved_models', exist_ok=True)

# Start training the aggressive model
print("🚀 Starting training of AGGRESSIVE CNN-LSTM-Transformer...")
history = train_aggressive_model(
    aggressive_model,
    train_loader,
    val_loader, 
    class_weights,
    epochs=50,
    initial_lr=1e-3,
    save_path='saved_models/aggressive_cnn_lstm_transformer.pth'
)