In [1]:
%pip install pretty_midi kagglehub

Collecting pretty_midi
  Downloading pretty_midi-0.2.10.tar.gz (5.6 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/5.6 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m5.6/5.6 MB[0m [31m178.5 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.6/5.6 MB[0m [31m107.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting mido>=1.1.16 (from pretty_midi)
  Downloading mido-1.3.3-py3-none-any.whl.metadata (6.4 kB)
Downloading mido-1.3.3-py3-none-any.whl (54 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.6/54.6 kB[0m [31m6.8 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: pretty_midi
  Building wheel for pretty_midi (setup.py) ... [?25l[?25hdone
  Created wheel for pretty_midi: filename=pretty_midi-0.2.10-py3-none-any.whl size=5592286 sha256=6b7d68d4fa0805eb4

In [2]:
import os
import kagglehub
import zipfile
import shutil
import numpy as np
import torch
import torch.nn as nn
import os
import numpy as np
import pretty_midi
import torch
from torch.utils.data import Dataset

In [3]:
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
if not hasattr(np, 'int'):
    np.int = int

In [None]:
# Define your model class (must match the architecture used in final-project1)
class CNN_LSTM_Classifier(nn.Module):
    def __init__(self, num_classes=4, lstm_hidden=256):
        super(CNN_LSTM_Classifier, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Dropout2d(0.2),
            nn.MaxPool2d(kernel_size=(2, 2))
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Dropout2d(0.3),
            nn.MaxPool2d(kernel_size=(2, 2))
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Dropout2d(0.4),
            nn.MaxPool2d(kernel_size=(2, 2))
        )
        self.feature_size = 128 * 16
        self.lstm = nn.LSTM(
            input_size=self.feature_size,
            hidden_size=lstm_hidden,
            num_layers=2,
            batch_first=True,
            dropout=0.3,
            bidirectional=True
        )
        self.attention = nn.MultiheadAttention(
            embed_dim=lstm_hidden * 2,
            num_heads=8,
            dropout=0.3,
            batch_first=True
        )
        self.classifier = nn.Sequential(
            nn.Linear(lstm_hidden * 2, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        batch_size = x.size(0)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.permute(0, 3, 1, 2)
        x = x.contiguous().view(batch_size, x.size(1), -1)
        lstm_out, _ = self.lstm(x)
        attn_out, _ = self.attention(lstm_out, lstm_out, lstm_out)
        pooled = torch.mean(attn_out, dim=1)
        output = self.classifier(pooled)
        return output


In [None]:
# Load the models
from IPython import get_ipython

if 'google.colab' in str(get_ipython()):
    print("Running in Google Colab.")
    original_path = os.path.join('/content/saved_models', 'original_cnn_lstm.pth')
    rhythm_path = os.path.join('/content/saved_models', 'rhythm_augmented_cnn_lstm.pth')
else:
    print("Not running in Google Colab.")
    original_path = os.path.join('saved_models', 'original_cnn_lstm.pth')
    rhythm_path = os.path.join('saved_models', 'rhythm_augmented_cnn_lstm.pth')

model = CNN_LSTM_Classifier(num_classes=4, lstm_hidden=256).to(device)
model.load_state_dict(torch.load(original_path, map_location=device))
print(f"✅ Loaded: {original_path}")

rhythm_model = CNN_LSTM_Classifier(num_classes=4, lstm_hidden=256).to(device)
rhythm_model.load_state_dict(torch.load(rhythm_path, map_location=device))
print(f"✅ Loaded: {rhythm_path}")

Running in Google Colab.
✅ Loaded: /content/saved_models/original_cnn_lstm.pth
✅ Loaded: /content/saved_models/rhythm_augmented_cnn_lstm.pth


In [4]:
TARGET_COMPOSERS = [
    'Bach',
    'Beethoven',
    'Chopin',
    'Mozart',
]

path = kagglehub.dataset_download("blanderbuss/midi-classic-music")

zip_path = os.path.join(path, 'midiclassics.zip')
extract_path = os.path.join('data', 'kaggle', 'midiclassics')


with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_path)

# list files in extract_path that contain the target composers in name
for composer in TARGET_COMPOSERS:
    composer_files = [f for f in os.listdir(extract_path) if composer.lower() in f.lower()]

# Only keep directories that contain a target composer's name
for item in os.listdir(extract_path):
    item_path = os.path.join(extract_path, item)
    if not any(composer.lower() in item.lower() for composer in TARGET_COMPOSERS):
        if os.path.isfile(item_path):
            os.remove(item_path)
        elif os.path.isdir(item_path):
            shutil.rmtree(item_path)

# also delete "C.P.E.Bach" files. This was the son of J.S. Bach, and we want to keep only the main composers
for item in os.listdir(extract_path):
    if 'C.P.E.Bach' in item:
        item_path = os.path.join(extract_path, item)
        if os.path.isfile(item_path):
            os.remove(item_path)
        elif os.path.isdir(item_path):
            shutil.rmtree(item_path)

In [5]:
class PianoRollDataset(Dataset):
    def __init__(self, data, labels):
        self.data = torch.tensor(data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        # Add channel dimension for CNN: (1, 128, T)
        return self.data[idx].unsqueeze(0), self.labels[idx]

In [6]:
def get_piano_roll_improved(midi_path, fs=100, target_duration=45.0):
    """
    Improved MIDI to piano roll conversion with musical awareness
    """
    try:
        pm = pretty_midi.PrettyMIDI(midi_path)

        # Get the actual duration of the piece
        actual_duration = pm.get_end_time()

        # If piece is very short, skip it
        if actual_duration < 10.0:  # Less than 10 seconds
            return None

        # For long pieces, extract multiple segments
        if actual_duration > target_duration * 1.5:
            # Extract from different parts of the piece
            segments = []
            num_segments = min(3, int(actual_duration // target_duration))

            for i in range(num_segments):
                start_time = i * (actual_duration / num_segments)
                end_time = start_time + target_duration

                # Create a copy and trim
                pm_segment = pretty_midi.PrettyMIDI()
                for instrument in pm.instruments:
                    new_instrument = pretty_midi.Instrument(
                        program=instrument.program,
                        is_drum=instrument.is_drum,
                        name=instrument.name
                    )

                    for note in instrument.notes:
                        if start_time <= note.start < end_time:
                            new_note = pretty_midi.Note(
                                velocity=note.velocity,
                                pitch=note.pitch,
                                start=note.start - start_time,
                                end=min(note.end - start_time, target_duration)
                            )
                            new_instrument.notes.append(new_note)

                    if new_instrument.notes:
                        pm_segment.instruments.append(new_instrument)

                if pm_segment.instruments:
                    piano_roll = pm_segment.get_piano_roll(fs=fs)
                    target_length = int(target_duration * fs)

                    if piano_roll.shape[1] > target_length:
                        piano_roll = piano_roll[:, :target_length]
                    else:
                        pad_width = target_length - piano_roll.shape[1]
                        piano_roll = np.pad(piano_roll, ((0,0),(0,pad_width)), mode='constant')

                    segments.append(piano_roll)

            return segments

        else:
            # For normal length pieces, use the whole piece
            piano_roll = pm.get_piano_roll(fs=fs)
            target_length = int(target_duration * fs)

            if piano_roll.shape[1] > target_length:
                # Take from the middle rather than truncating end
                start_idx = (piano_roll.shape[1] - target_length) // 2
                piano_roll = piano_roll[:, start_idx:start_idx + target_length]
            else:
                pad_width = target_length - piano_roll.shape[1]
                piano_roll = np.pad(piano_roll, ((0,0),(0,pad_width)), mode='constant')

            return [piano_roll]

    except Exception as e:
        print(f"Error processing {midi_path}: {e}")
        return None


In [7]:
def normalize_piano_roll(piano_roll):
    """
    Apply musical normalization to piano roll
    """
    # 1. Velocity normalization (already 0-1 from pretty_midi)
    normalized = piano_roll.copy()

    # 2. Optional: Focus on active pitch range
    active_pitches = np.any(normalized > 0, axis=1)
    if np.any(active_pitches):
        first_active = np.argmax(active_pitches)
        last_active = len(active_pitches) - 1 - np.argmax(active_pitches[::-1])

        # Ensure we keep a reasonable range (at least 60 semitones = 5 octaves)
        min_range = 60
        current_range = last_active - first_active + 1

        if current_range < min_range:
            expand = (min_range - current_range) // 2
            first_active = max(0, first_active - expand)
            last_active = min(127, last_active + expand)

    return normalized


In [8]:
def extract_musical_features(piano_roll):
    """
    Extract features that capture musical style
    """
    features = {}

    # Temporal features
    note_density_timeline = np.sum(piano_roll > 0, axis=0)
    features['avg_notes_per_time'] = np.mean(note_density_timeline)
    features['note_density_variance'] = np.var(note_density_timeline)

    # Pitch features
    pitch_activity = np.sum(piano_roll > 0, axis=1)
    active_pitches = pitch_activity > 0
    if np.any(active_pitches):
        features['pitch_range'] = np.sum(active_pitches)
        features['lowest_pitch'] = np.argmax(active_pitches)
        features['highest_pitch'] = 127 - np.argmax(active_pitches[::-1])
    else:
        features['pitch_range'] = 0
        features['lowest_pitch'] = 60  # Middle C
        features['highest_pitch'] = 60

    # Rhythmic features
    onset_pattern = np.diff(note_density_timeline > 0).astype(int)
    features['onset_density'] = np.sum(onset_pattern == 1) / len(onset_pattern)

    return features

print("✅ Improved data processing functions defined!")
print("Key improvements:")
print("• Intelligent segment extraction for long pieces")
print("• Musical boundary awareness")
print("• Better normalization")
print("• Feature extraction for analysis")

✅ Improved data processing functions defined!
Key improvements:
• Intelligent segment extraction for long pieces
• Musical boundary awareness
• Better normalization
• Feature extraction for analysis


In [9]:
# =====================================================
# IMPROVED DATA LOADING WITH BETTER PROCESSING
# =====================================================

import numpy as np

def load_improved_dataset(extract_path, target_composers, target_duration=45.0, max_files_per_composer=None):
    """
    Load dataset with improved processing that addresses previous shortcomings
    """
    print("🎵 LOADING DATASET WITH IMPROVED PROCESSING...")
    print("Improvements over original:")
    print("• Intelligent segment extraction for long pieces")
    print("• Better handling of piece lengths")
    print("• Musical feature extraction")
    print("• Quality filtering")

    composer_to_idx = {c: i for i, c in enumerate(target_composers)}
    all_data = []
    all_labels = []
    all_features = []

    for composer in target_composers:
        print(f"\n--- Processing {composer} ---")
        composer_dir = os.path.join(extract_path, composer)

        if not os.path.isdir(composer_dir):
            print(f"Directory not found: {composer_dir}")
            continue

        composer_data = []
        composer_labels = []
        composer_features = []
        files_processed = 0
        segments_created = 0

        midi_files = [f for f in os.listdir(composer_dir)
                     if f.lower().endswith(('.mid', '.midi'))]

        if max_files_per_composer:
            midi_files = midi_files[:max_files_per_composer]

        for file in midi_files:
            midi_path = os.path.join(composer_dir, file)

            try:
                # Use improved processing
                segments = get_piano_roll_improved(midi_path, target_duration=target_duration)

                if segments is None:
                    continue

                for segment in segments:
                    # Normalize the segment
                    normalized_segment = normalize_piano_roll(segment)

                    # Extract musical features
                    features = extract_musical_features(normalized_segment)

                    # Quality check: skip if too sparse
                    note_density = features['avg_notes_per_time']
                    if note_density < 0.1:  # Very sparse, likely poor quality
                        continue

                    composer_data.append(normalized_segment)
                    composer_labels.append(composer_to_idx[composer])
                    composer_features.append(features)
                    segments_created += 1

                files_processed += 1

                if files_processed % 10 == 0:
                    print(f"  Processed {files_processed} files, created {segments_created} segments...")

            except Exception as e:
                print(f"  Error processing {file}: {e}")
                continue

        print(f"✅ {composer}: {files_processed} files → {segments_created} segments")

        if composer_data:
            composer_data = np.array(composer_data)
            composer_labels = np.array(composer_labels)

            all_data.append(composer_data)
            all_labels.append(composer_labels)
            all_features.extend(composer_features)

            print(f"  Final data shape: {composer_data.shape}")

    # Combine all data
    if all_data:
        data = np.concatenate(all_data, axis=0)
        labels = np.concatenate(all_labels, axis=0)

        print(f"\n🎯 FINAL IMPROVED DATASET:")
        print(f"Total samples: {len(data)}")
        print(f"Data shape: {data.shape}")
        print(f"Label distribution: {np.bincount(labels)}")

        return data, labels, all_features
    else:
        print("❌ No data loaded!")
        return None, None, None

# Load the improved dataset
print("🚀 Starting improved data loading...")
improved_data, improved_labels, features = load_improved_dataset(
    extract_path,
    TARGET_COMPOSERS,
    target_duration=45.0,  # 45 seconds per segment
    max_files_per_composer=120  # Limit for testing - remove for full dataset
)

🚀 Starting improved data loading...
🎵 LOADING DATASET WITH IMPROVED PROCESSING...
Improvements over original:
• Intelligent segment extraction for long pieces
• Better handling of piece lengths
• Musical feature extraction
• Quality filtering

--- Processing Bach ---
  Processed 10 files, created 25 segments...




  Processed 20 files, created 51 segments...
  Processed 30 files, created 77 segments...
  Processed 40 files, created 104 segments...
  Processed 50 files, created 130 segments...
  Processed 60 files, created 154 segments...
  Processed 70 files, created 182 segments...
  Processed 80 files, created 211 segments...
  Processed 90 files, created 233 segments...
  Processed 100 files, created 261 segments...
  Processed 110 files, created 286 segments...
  Processed 120 files, created 312 segments...
✅ Bach: 120 files → 312 segments
  Final data shape: (312, 128, 4500)

--- Processing Beethoven ---
  Processed 10 files, created 26 segments...
  Processed 20 files, created 52 segments...
Error processing data/kaggle/midiclassics/Beethoven/Anhang 14-3.mid: Could not decode key with 3 flats and mode 255
  Processed 30 files, created 78 segments...
  Processed 40 files, created 105 segments...
  Processed 50 files, created 126 segments...
  Processed 60 files, created 153 segments...
  Pr

In [None]:
# =====================================================
# TEST TRAINED MODELS ON IMPROVED DATASET
# =====================================================

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# Create dataset and dataloader for improved data
improved_dataset = PianoRollDataset(improved_data, improved_labels)
test_loader = DataLoader(improved_dataset, batch_size=32, shuffle=False)

def evaluate_model(model, dataloader, model_name):
    """Evaluate a single model"""
    model.eval()
    all_preds = []
    all_targets = []
    all_probs = []

    with torch.no_grad():
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            probs = F.softmax(outputs, dim=1)
            preds = torch.argmax(outputs, dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(target.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    accuracy = 100 * np.mean(np.array(all_preds) == np.array(all_targets))

    print(f"\n📊 {model_name} Results:")
    print(f"Accuracy: {accuracy:.2f}%")
    print("\nClassification Report:")
    print(classification_report(all_targets, all_preds,
                              target_names=TARGET_COMPOSERS,
                              digits=3))

    return all_preds, all_targets, all_probs, accuracy

# Test both models
print("🧪 TESTING MODELS ON IMPROVED DATASET")
print("="*50)

# Test original model
orig_preds, targets, orig_probs, orig_acc = evaluate_model(model, test_loader, "Original Model")

# Test rhythm model
rhythm_preds, _, rhythm_probs, rhythm_acc = evaluate_model(rhythm_model, test_loader, "Rhythm Model")

# Convert to numpy arrays
orig_probs = np.array(orig_probs)
rhythm_probs = np.array(rhythm_probs)
targets = np.array(targets)

🧪 TESTING MODELS ON IMPROVED DATASET

📊 Original Model Results:
Accuracy: 68.32%

Classification Report:
              precision    recall  f1-score   support

        Bach      0.717     0.875     0.788       312
   Beethoven      0.682     0.480     0.564       304
      Chopin      0.775     0.724     0.749       286
      Mozart      0.553     0.643     0.595       244

    accuracy                          0.683      1146
   macro avg      0.682     0.681     0.674      1146
weighted avg      0.687     0.683     0.677      1146


📊 Rhythm Model Results:
Accuracy: 72.08%

Classification Report:
              precision    recall  f1-score   support

        Bach      0.839     0.817     0.828       312
   Beethoven      0.623     0.724     0.670       304
      Chopin      0.835     0.689     0.755       286
      Mozart      0.609     0.631     0.620       244

    accuracy                          0.721      1146
   macro avg      0.726     0.715     0.718      1146
weighted avg  

In [10]:
def get_piano_roll_non_overlapping_segments(midi_path, fs=100, segment_duration=45.0):
    """
    Extract non-overlapping segments to avoid data leakage with LSTM
    """
    try:
        pm = pretty_midi.PrettyMIDI(midi_path)

        # Get the actual duration of the piece
        actual_duration = pm.get_end_time()

        # If piece is very short, skip it
        if actual_duration < 15.0:  # Less than 15 seconds
            return None

        segments = []
        segment_size = segment_duration

        # Calculate number of non-overlapping segments
        num_segments = int(actual_duration // segment_size)

        # If piece is shorter than one segment, use the whole piece (padded)
        if num_segments == 0:
            piano_roll = pm.get_piano_roll(fs=fs)
            target_length = int(segment_duration * fs)

            if piano_roll.shape[1] < target_length:
                pad_width = target_length - piano_roll.shape[1]
                piano_roll = np.pad(piano_roll, ((0,0),(0,pad_width)), mode='constant')
            else:
                piano_roll = piano_roll[:, :target_length]

            segments.append(piano_roll)
            return segments

        # Extract non-overlapping segments
        for i in range(num_segments):
            start_time = i * segment_size
            end_time = start_time + segment_size

            # Create a copy and trim
            pm_segment = pretty_midi.PrettyMIDI()
            for instrument in pm.instruments:
                new_instrument = pretty_midi.Instrument(
                    program=instrument.program,
                    is_drum=instrument.is_drum,
                    name=instrument.name
                )

                for note in instrument.notes:
                    if start_time <= note.start < end_time:
                        new_note = pretty_midi.Note(
                            velocity=note.velocity,
                            pitch=note.pitch,
                            start=note.start - start_time,
                            end=min(note.end - start_time, segment_duration)
                        )
                        new_instrument.notes.append(new_note)

                if new_instrument.notes:
                    pm_segment.instruments.append(new_instrument)

            if pm_segment.instruments:
                piano_roll = pm_segment.get_piano_roll(fs=fs)
                target_length = int(segment_duration * fs)

                if piano_roll.shape[1] > target_length:
                    piano_roll = piano_roll[:, :target_length]
                else:
                    pad_width = target_length - piano_roll.shape[1]
                    piano_roll = np.pad(piano_roll, ((0,0),(0,pad_width)), mode='constant')

                segments.append(piano_roll)

        return segments if segments else None

    except Exception as e:
        print(f"Error processing {midi_path}: {e}")
        return None

print("✅ Non-overlapping segmentation function defined!")

✅ Non-overlapping segmentation function defined!


In [11]:
def load_dataset_non_overlapping_full(extract_path, target_composers, segment_duration=45.0):
    """
    Load FULL dataset with non-overlapping segments - using ALL available files
    """
    print("🎵 LOADING FULL DATASET WITH NON-OVERLAPPING SEGMENTS...")
    print("Benefits for LSTM training:")
    print(f"• Segment duration: {segment_duration}s (no overlap)")
    print("• Using ALL available files (no limits)")
    print("• No data leakage between train/test")
    print("• Cleaner temporal boundaries")
    print("• Will address class imbalance later during training")

    composer_to_idx = {c: i for i, c in enumerate(target_composers)}
    all_data = []
    all_labels = []
    all_features = []

    total_files_processed = 0
    total_segments_created = 0

    for composer in target_composers:
        print(f"\n--- Processing {composer} ---")
        composer_dir = os.path.join(extract_path, composer)

        if not os.path.isdir(composer_dir):
            print(f"Directory not found: {composer_dir}")
            continue

        composer_data = []
        composer_labels = []
        composer_features = []
        files_processed = 0
        segments_created = 0

        # Get ALL MIDI files - no limit
        midi_files = [f for f in os.listdir(composer_dir)
                     if f.lower().endswith(('.mid', '.midi'))]

        print(f"  Found {len(midi_files)} MIDI files for {composer}")

        for file in midi_files:
            midi_path = os.path.join(composer_dir, file)

            try:
                # Use non-overlapping segmentation
                segments = get_piano_roll_non_overlapping_segments(
                    midi_path,
                    segment_duration=segment_duration
                )

                if segments is None:
                    continue

                for segment in segments:
                    # Normalize the segment
                    normalized_segment = normalize_piano_roll(segment)

                    # Extract musical features
                    features = extract_musical_features(normalized_segment)

                    # Quality check: skip if too sparse
                    note_density = features['avg_notes_per_time']
                    if note_density < 0.1:  # Very sparse, likely poor quality
                        continue

                    composer_data.append(normalized_segment)
                    composer_labels.append(composer_to_idx[composer])
                    composer_features.append(features)
                    segments_created += 1

                files_processed += 1

                # Progress update every 50 files for full dataset
                if files_processed % 50 == 0:
                    print(f"  Processed {files_processed}/{len(midi_files)} files, created {segments_created} segments...")

            except Exception as e:
                print(f"  Error processing {file}: {e}")
                continue

        print(f"✅ {composer}: {files_processed}/{len(midi_files)} files → {segments_created} segments")

        if composer_data:
            composer_data = np.array(composer_data)
            composer_labels = np.array(composer_labels)

            all_data.append(composer_data)
            all_labels.append(composer_labels)
            all_features.extend(composer_features)

            print(f"  Final data shape: {composer_data.shape}")

        total_files_processed += files_processed
        total_segments_created += segments_created

    # Combine all data
    if all_data:
        data = np.concatenate(all_data, axis=0)
        labels = np.concatenate(all_labels, axis=0)

        print(f"\n🎯 FINAL FULL NON-OVERLAPPING DATASET:")
        print(f"Total files processed: {total_files_processed}")
        print(f"Total samples: {len(data)}")
        print(f"Data shape: {data.shape}")
        print(f"Label distribution: {np.bincount(labels)}")

        # Show class distribution percentages
        for i, composer in enumerate(target_composers):
            count = np.sum(labels == i)
            percentage = (count / len(labels)) * 100
            print(f"  {composer}: {count} samples ({percentage:.1f}%)")

        return data, labels, all_features
    else:
        print("❌ No data loaded!")
        return None, None, None

# Load the FULL dataset with non-overlapping segments
print("🚀 Starting FULL non-overlapping segment data loading...")
full_data, full_labels, full_features = load_dataset_non_overlapping_full(
    extract_path,
    TARGET_COMPOSERS,
    segment_duration=45.0  # 45-second segments, no overlap, ALL files
)

🚀 Starting FULL non-overlapping segment data loading...
🎵 LOADING FULL DATASET WITH NON-OVERLAPPING SEGMENTS...
Benefits for LSTM training:
• Segment duration: 45.0s (no overlap)
• Using ALL available files (no limits)
• No data leakage between train/test
• Cleaner temporal boundaries
• Will address class imbalance later during training

--- Processing Bach ---
  Found 131 MIDI files for Bach
  Processed 50/131 files, created 394 segments...
  Processed 100/131 files, created 715 segments...
✅ Bach: 131/131 files → 977 segments
  Final data shape: (977, 128, 4500)

--- Processing Beethoven ---
  Found 134 MIDI files for Beethoven
Error processing data/kaggle/midiclassics/Beethoven/Anhang 14-3.mid: Could not decode key with 3 flats and mode 255
  Processed 50/134 files, created 357 segments...
  Processed 100/134 files, created 765 segments...
✅ Beethoven: 133/134 files → 987 segments
  Final data shape: (987, 128, 4500)

--- Processing Chopin ---
  Found 136 MIDI files for Chopin
  Pro

In [12]:
# =====================================================
# AGGRESSIVE CNN-LSTM-TRANSFORMER FOR A100 40GB
# =====================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class PositionalEncoding(nn.Module):
    """Enhanced positional encoding for transformer"""
    def __init__(self, d_model, max_len=10000):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)

        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:x.size(0), :]

class AggressiveCNN_LSTM_Transformer(nn.Module):
    def __init__(self, num_classes=4, lstm_hidden=512, transformer_dim=1024, num_heads=16, num_layers=8):
        super(AggressiveCNN_LSTM_Transformer, self).__init__()

        print("🚀 Building AGGRESSIVE CNN-LSTM-Transformer for A100 40GB...")
        print(f"• Deep CNN feature extraction (6 blocks)")
        print(f"• Large LSTM temporal modeling (hidden: {lstm_hidden})")
        print(f"• Deep Transformer self-attention (dim: {transformer_dim}, heads: {num_heads}, layers: {num_layers})")
        print(f"• Multi-scale feature fusion")
        print(f"• Advanced attention mechanisms")

        # ==========================================
        # DEEP CNN BACKBONE - 6 BLOCKS
        # ==========================================

        # Block 1: Initial feature extraction
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Dropout2d(0.1),
            nn.MaxPool2d(kernel_size=(2, 2))  # 128x64 -> 64x32
        )

        # Block 2: Deeper features
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Dropout2d(0.15),
            nn.MaxPool2d(kernel_size=(2, 2))  # 64x32 -> 32x16
        )

        # Block 3: More complex patterns
        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Dropout2d(0.2),
            nn.MaxPool2d(kernel_size=(2, 2))  # 32x16 -> 16x8
        )

        # Block 4: High-level features
        self.conv4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Dropout2d(0.25),
            nn.MaxPool2d(kernel_size=(2, 2))  # 16x8 -> 8x4
        )

        # Block 5: Abstract features
        self.conv5 = nn.Sequential(
            nn.Conv2d(512, 768, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(768),
            nn.ReLU(),
            nn.Conv2d(768, 768, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(768),
            nn.ReLU(),
            nn.Dropout2d(0.3),
            nn.MaxPool2d(kernel_size=(2, 2))  # 8x4 -> 4x2
        )

        # Block 6: Final feature extraction
        self.conv6 = nn.Sequential(
            nn.Conv2d(768, 1024, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.Conv2d(1024, 1024, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.Dropout2d(0.35),
            nn.AdaptiveAvgPool2d((2, 1))  # Ensure consistent output: 2x1
        )

        # ==========================================
        # LARGE BIDIRECTIONAL LSTM
        # ==========================================
        self.feature_size = 1024 * 2  # 1024 channels * 2x1 spatial
        self.lstm_hidden = lstm_hidden

        # Multi-layer LSTM with larger capacity
        self.lstm = nn.LSTM(
            input_size=self.feature_size,
            hidden_size=lstm_hidden,
            num_layers=4,  # Deeper LSTM
            batch_first=True,
            dropout=0.3,
            bidirectional=True
        )

        # Additional LSTM for temporal refinement
        self.lstm_refine = nn.LSTM(
            input_size=lstm_hidden * 2,
            hidden_size=lstm_hidden // 2,
            num_layers=2,
            batch_first=True,
            dropout=0.2,
            bidirectional=True
        )

        # ==========================================
        # DEEP TRANSFORMER ENCODER
        # ==========================================
        self.transformer_dim = transformer_dim

        # Project LSTM output to transformer dimension
        self.lstm_to_transformer = nn.Sequential(
            nn.Linear(lstm_hidden, transformer_dim),
            nn.LayerNorm(transformer_dim),
            nn.ReLU(),
            nn.Dropout(0.1)
        )

        # Positional encoding
        self.pos_encoding = PositionalEncoding(transformer_dim, max_len=10000)

        # Deep transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=transformer_dim,
            nhead=num_heads,
            dim_feedforward=transformer_dim * 4,  # Large feedforward
            dropout=0.1,
            activation='gelu',  # GELU activation for better performance
            batch_first=True,
            norm_first=True  # Pre-norm for better training stability
        )

        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,
            norm=nn.LayerNorm(transformer_dim)
        )

        # ==========================================
        # MULTI-SCALE ATTENTION & FUSION
        # ==========================================

        # Multi-scale attention heads
        self.global_attention = nn.MultiheadAttention(
            embed_dim=transformer_dim,
            num_heads=num_heads,
            dropout=0.1,
            batch_first=True
        )

        self.local_attention = nn.MultiheadAttention(
            embed_dim=transformer_dim,
            num_heads=num_heads // 2,
            dropout=0.1,
            batch_first=True
        )

        # Cross-attention between global and local features
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=transformer_dim,
            num_heads=num_heads // 2,
            dropout=0.1,
            batch_first=True
        )

        # Feature fusion
        self.feature_fusion = nn.Sequential(
            nn.Linear(transformer_dim * 3, transformer_dim),
            nn.LayerNorm(transformer_dim),
            nn.GELU(),
            nn.Dropout(0.2)
        )

        # ==========================================
        # ADVANCED CLASSIFICATION HEAD
        # ==========================================

        # Hierarchical classification with multiple paths
        self.classifier = nn.Sequential(
            nn.LayerNorm(transformer_dim),
            nn.Linear(transformer_dim, 2048),
            nn.GELU(),
            nn.Dropout(0.5),

            nn.Linear(2048, 1024),
            nn.GELU(),
            nn.Dropout(0.4),

            nn.Linear(1024, 512),
            nn.GELU(),
            nn.Dropout(0.3),

            nn.Linear(512, 256),
            nn.GELU(),
            nn.Dropout(0.2),

            nn.Linear(256, num_classes)
        )

        # Auxiliary classifier for regularization
        self.aux_classifier = nn.Sequential(
            nn.Linear(lstm_hidden, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

        print("✅ AGGRESSIVE CNN-LSTM-Transformer architecture built!")

    def forward(self, x):
        batch_size = x.size(0)

        # ==========================================
        # DEEP CNN FEATURE EXTRACTION
        # ==========================================
        x = self.conv1(x)      # (batch, 64, 64, T/2)
        x = self.conv2(x)      # (batch, 128, 32, T/4)
        x = self.conv3(x)      # (batch, 256, 16, T/8)
        x = self.conv4(x)      # (batch, 512, 8, T/16)
        x = self.conv5(x)      # (batch, 768, 4, T/32)
        x = self.conv6(x)      # (batch, 1024, 2, T/64)

        # Reshape for LSTM: (batch, time_steps, features)
        x = x.permute(0, 3, 1, 2)  # (batch, T/64, 1024, 2)
        x = x.contiguous().view(batch_size, x.size(1), -1)  # (batch, T/64, 1024*2)

        # ==========================================
        # LARGE LSTM PROCESSING
        # ==========================================
        lstm_out, _ = self.lstm(x)  # (batch, T/64, lstm_hidden*2)
        lstm_refined, _ = self.lstm_refine(lstm_out)  # (batch, T/64, lstm_hidden)

        # Auxiliary classification from LSTM features (for regularization)
        lstm_pooled = torch.mean(lstm_refined, dim=1)
        aux_output = self.aux_classifier(lstm_pooled)

        # ==========================================
        # DEEP TRANSFORMER PROCESSING
        # ==========================================

        # Project to transformer dimension
        transformer_input = self.lstm_to_transformer(lstm_refined)  # (batch, T/64, transformer_dim)

        # Add positional encoding
        transformer_input = transformer_input.transpose(0, 1)  # (T/64, batch, transformer_dim)
        transformer_input = self.pos_encoding(transformer_input)
        transformer_input = transformer_input.transpose(0, 1)  # (batch, T/64, transformer_dim)

        # Deep transformer encoding
        transformer_out = self.transformer_encoder(transformer_input)  # (batch, T/64, transformer_dim)

        # ==========================================
        # MULTI-SCALE ATTENTION & FUSION
        # ==========================================

        # Global attention (full sequence)
        global_attended, _ = self.global_attention(
            transformer_out, transformer_out, transformer_out
        )

        # Local attention (sliding window - simulate by chunking)
        seq_len = transformer_out.size(1)
        if seq_len > 16:
            # Use overlapping windows
            local_features = []
            window_size = min(16, seq_len)
            for i in range(0, max(1, seq_len - window_size + 1), window_size // 2):
                end_idx = min(i + window_size, seq_len)
                window = transformer_out[:, i:end_idx, :]
                local_att, _ = self.local_attention(window, window, window)
                local_features.append(torch.mean(local_att, dim=1, keepdim=True))
            local_attended = torch.cat(local_features, dim=1)
        else:
            local_attended, _ = self.local_attention(
                transformer_out, transformer_out, transformer_out
            )

        # Cross attention between global and local
        cross_attended, _ = self.cross_attention(
            global_attended, local_attended, local_attended
        )

        # Fusion of multi-scale features
        # Pool to same size for concatenation
        global_pooled = torch.mean(global_attended, dim=1)
        local_pooled = torch.mean(local_attended, dim=1)
        cross_pooled = torch.mean(cross_attended, dim=1)

        fused_features = torch.cat([global_pooled, local_pooled, cross_pooled], dim=1)
        final_features = self.feature_fusion(fused_features)

        # ==========================================
        # CLASSIFICATION
        # ==========================================
        main_output = self.classifier(final_features)

        return main_output, aux_output

# Create the aggressive model
print("🚀 Creating AGGRESSIVE model for A100 40GB...")
aggressive_model = AggressiveCNN_LSTM_Transformer(
    num_classes=4,
    lstm_hidden=512,        # Doubled from 256
    transformer_dim=1024,   # Doubled from 512
    num_heads=16,          # Doubled from 8
    num_layers=8           # Doubled from 4
).to(device)

# Test with dummy input
test_input = torch.randn(4, 1, 128, 4500).to(device)  # Larger batch
print(f"Input shape: {test_input.shape}")

with torch.no_grad():
    main_out, aux_out = aggressive_model(test_input)
    print(f"Main output shape: {main_out.shape}")
    print(f"Auxiliary output shape: {aux_out.shape}")
    print(f"✅ Aggressive model forward pass successful!")

# Count parameters
total_params = sum(p.numel() for p in aggressive_model.parameters())
trainable_params = sum(p.numel() for p in aggressive_model.parameters() if p.requires_grad)

print(f"\n📊 AGGRESSIVE Model Statistics:")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Model size: ~{total_params * 4 / 1024 / 1024:.1f} MB")
print(f"Estimated GPU memory (training): ~{total_params * 16 / 1024 / 1024:.1f} MB")

🚀 Creating AGGRESSIVE model for A100 40GB...
🚀 Building AGGRESSIVE CNN-LSTM-Transformer for A100 40GB...
• Deep CNN feature extraction (6 blocks)
• Large LSTM temporal modeling (hidden: 512)
• Deep Transformer self-attention (dim: 1024, heads: 16, layers: 8)
• Multi-scale feature fusion
• Advanced attention mechanisms




✅ AGGRESSIVE CNN-LSTM-Transformer architecture built!
Input shape: torch.Size([4, 1, 128, 4500])
Main output shape: torch.Size([4, 4])
Auxiliary output shape: torch.Size([4, 4])
✅ Aggressive model forward pass successful!

📊 AGGRESSIVE Model Statistics:
Total parameters: 185,688,776
Trainable parameters: 185,688,776
Model size: ~708.3 MB
Estimated GPU memory (training): ~2833.4 MB


In [13]:
# =====================================================
# SEQUENCE-AWARE DATA LOADING WITH PIECE TRACKING
# =====================================================

def load_dataset_with_piece_tracking(extract_path, target_composers, segment_duration=45.0):
    """
    Load dataset while tracking which segments belong to the same piece
    This enables sequence-aware training that distinguishes:
    - Consecutive segments from SAME piece vs DIFFERENT pieces
    - Same composer vs different composer relationships
    """
    print("🎵 LOADING DATASET WITH PIECE TRACKING...")
    print("Key improvements:")
    print(f"• Track piece identity for each segment")
    print(f"• Enable consecutive segment modeling")
    print(f"• Support contrastive learning approaches")
    print(f"• Segment duration: {segment_duration}s")

    composer_to_idx = {c: i for i, c in enumerate(target_composers)}
    all_data = []
    all_labels = []
    all_piece_ids = []  # NEW: Track which piece each segment comes from
    all_piece_names = []  # NEW: Track piece names for analysis

    piece_id_counter = 0
    total_files_processed = 0
    total_segments_created = 0

    for composer in target_composers:
        print(f"\n--- Processing {composer} ---")
        composer_dir = os.path.join(extract_path, composer)

        if not os.path.isdir(composer_dir):
            print(f"Directory not found: {composer_dir}")
            continue

        midi_files = [f for f in os.listdir(composer_dir)
                     if f.lower().endswith(('.mid', '.midi'))]

        print(f"  Found {len(midi_files)} MIDI files for {composer}")
        files_processed = 0
        segments_created = 0

        for file in midi_files:
            midi_path = os.path.join(composer_dir, file)

            try:
                # Use non-overlapping segmentation
                segments = get_piano_roll_non_overlapping_segments(
                    midi_path, segment_duration=segment_duration
                )

                if segments is None or len(segments) == 0:
                    continue

                # All segments from this file get the same piece_id
                current_piece_id = piece_id_counter
                piece_name = f"{composer}_{file}"
                piece_id_counter += 1

                valid_segments_from_piece = 0

                for segment_idx, segment in enumerate(segments):
                    # Normalize the segment
                    normalized_segment = normalize_piano_roll(segment)

                    # Extract musical features
                    features = extract_musical_features(normalized_segment)

                    # Quality check: skip if too sparse
                    if features['avg_notes_per_time'] < 0.1:
                        continue

                    all_data.append(normalized_segment)
                    all_labels.append(composer_to_idx[composer])
                    all_piece_ids.append(current_piece_id)
                    all_piece_names.append(f"{piece_name}_seg{segment_idx}")

                    valid_segments_from_piece += 1
                    segments_created += 1

                if valid_segments_from_piece > 0:
                    files_processed += 1

                # Progress update
                if files_processed % 50 == 0:
                    print(f"  Processed {files_processed}/{len(midi_files)} files, created {segments_created} segments...")

            except Exception as e:
                print(f"  Error processing {file}: {e}")
                continue

        print(f"✅ {composer}: {files_processed}/{len(midi_files)} files → {segments_created} segments")
        total_files_processed += files_processed
        total_segments_created += segments_created

    # Convert to numpy arrays
    data = np.array(all_data)
    labels = np.array(all_labels)
    piece_ids = np.array(all_piece_ids)

    print(f"\n🎯 FINAL PIECE-TRACKED DATASET:")
    print(f"Total files processed: {total_files_processed}")
    print(f"Total segments: {len(data)}")
    print(f"Total unique pieces: {len(np.unique(piece_ids))}")
    print(f"Data shape: {data.shape}")
    print(f"Label distribution: {np.bincount(labels)}")

    # Analyze segments per piece
    segments_per_piece = []
    for piece_id in np.unique(piece_ids):
        count = np.sum(piece_ids == piece_id)
        segments_per_piece.append(count)

    print(f"Segments per piece - Mean: {np.mean(segments_per_piece):.1f}, "
          f"Min: {np.min(segments_per_piece)}, Max: {np.max(segments_per_piece)}")

    # Show class distribution percentages
    for i, composer in enumerate(target_composers):
        count = np.sum(labels == i)
        percentage = (count / len(labels)) * 100
        print(f"  {composer}: {count} segments ({percentage:.1f}%)")

    return data, labels, piece_ids, all_piece_names

# Load the dataset with piece tracking
print("🚀 Starting piece-tracked data loading...")
tracked_data, tracked_labels, tracked_piece_ids, piece_names = load_dataset_with_piece_tracking(
    extract_path,
    TARGET_COMPOSERS,
    segment_duration=45.0
)

🚀 Starting piece-tracked data loading...
🎵 LOADING DATASET WITH PIECE TRACKING...
Key improvements:
• Track piece identity for each segment
• Enable consecutive segment modeling
• Support contrastive learning approaches
• Segment duration: 45.0s

--- Processing Bach ---
  Found 131 MIDI files for Bach
  Processed 50/131 files, created 394 segments...
  Processed 100/131 files, created 715 segments...
✅ Bach: 131/131 files → 977 segments

--- Processing Beethoven ---
  Found 134 MIDI files for Beethoven
Error processing data/kaggle/midiclassics/Beethoven/Anhang 14-3.mid: Could not decode key with 3 flats and mode 255
  Processed 50/134 files, created 357 segments...
  Processed 100/134 files, created 765 segments...
✅ Beethoven: 133/134 files → 987 segments

--- Processing Chopin ---
  Found 136 MIDI files for Chopin
  Processed 50/136 files, created 218 segments...
  Processed 100/136 files, created 449 segments...
✅ Chopin: 136/136 files → 610 segments

--- Processing Mozart ---
  Fou

In [14]:
# =====================================================
# CLASS WEIGHTS IMPLEMENTATION
# =====================================================

from sklearn.utils.class_weight import compute_class_weight
import torch.nn as nn

# Compute class weights for our imbalanced dataset
def compute_class_weights(labels, target_composers):
    """
    Compute class weights to handle imbalance
    """
    print("⚖️ COMPUTING CLASS WEIGHTS...")

    # Get unique labels and compute balanced weights
    unique_labels = np.unique(labels)
    class_weights = compute_class_weight(
        'balanced',
        classes=unique_labels,
        y=labels
    )

    print(f"Original distribution:")
    for i, composer in enumerate(target_composers):
        count = np.sum(labels == i)
        percentage = (count / len(labels)) * 100
        print(f"  {composer}: {count} samples ({percentage:.1f}%)")

    print(f"\nComputed class weights:")
    for i, (composer, weight) in enumerate(zip(target_composers, class_weights)):
        print(f"  {composer}: {weight:.3f}")

    # Convert to PyTorch tensor
    class_weights_tensor = torch.FloatTensor(class_weights).to(device)

    print(f"\n✅ Class weights ready for CrossEntropyLoss")
    return class_weights_tensor

# Compute class weights for our tracked dataset
class_weights = compute_class_weights(tracked_labels, TARGET_COMPOSERS)

⚖️ COMPUTING CLASS WEIGHTS...
Original distribution:
  Bach: 977 samples (31.5%)
  Beethoven: 987 samples (31.9%)
  Chopin: 610 samples (19.7%)
  Mozart: 524 samples (16.9%)

Computed class weights:
  Bach: 0.793
  Beethoven: 0.785
  Chopin: 1.270
  Mozart: 1.478

✅ Class weights ready for CrossEntropyLoss


In [15]:
# =====================================================
# SEQUENCE-AWARE DATASET FOR CONSECUTIVE SEGMENTS
# =====================================================

from torch.utils.data import Dataset
import random

class SequenceAwareDataset(Dataset):
    """
    Dataset that creates sequences of consecutive segments from same pieces
    This enables the model to learn temporal relationships within compositions
    """
    def __init__(self, data, labels, piece_ids, sequence_length=2, include_singles=True):
        self.data = torch.tensor(data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)
        self.piece_ids = torch.tensor(piece_ids, dtype=torch.long)
        self.sequence_length = sequence_length
        self.include_singles = include_singles

        # Group segments by piece
        self.piece_segments = {}
        for idx, piece_id in enumerate(piece_ids):
            piece_id = int(piece_id)
            if piece_id not in self.piece_segments:
                self.piece_segments[piece_id] = []
            self.piece_segments[piece_id].append(idx)

        # Create sequence indices
        self.sequence_indices = self._create_sequence_indices()

        print(f"📊 SEQUENCE DATASET CREATED:")
        print(f"• Sequence length: {sequence_length}")
        print(f"• Total pieces: {len(self.piece_segments)}")
        print(f"• Total sequences: {len(self.sequence_indices)}")
        print(f"• Include single segments: {include_singles}")

    def _create_sequence_indices(self):
        """Create indices for all possible sequences"""
        sequences = []

        # Add consecutive sequences from same pieces
        for piece_id, segment_indices in self.piece_segments.items():
            if len(segment_indices) >= self.sequence_length:
                # Create all possible consecutive sequences from this piece
                for start_idx in range(len(segment_indices) - self.sequence_length + 1):
                    seq_indices = segment_indices[start_idx:start_idx + self.sequence_length]
                    sequences.append({
                        'type': 'sequence',
                        'indices': seq_indices,
                        'piece_id': piece_id,
                        'label': int(self.labels[seq_indices[0]])  # All should have same label
                    })

        # Optionally add single segments
        if self.include_singles:
            for piece_id, segment_indices in self.piece_segments.items():
                for idx in segment_indices:
                    sequences.append({
                        'type': 'single',
                        'indices': [idx],
                        'piece_id': piece_id,
                        'label': int(self.labels[idx])
                    })

        return sequences

    def __len__(self):
        return len(self.sequence_indices)

    def __getitem__(self, idx):
        sequence_info = self.sequence_indices[idx]
        indices = sequence_info['indices']

        if len(indices) == 1:
            # Single segment
            segment = self.data[indices[0]].unsqueeze(0)  # Add channel dim
            return segment, sequence_info['label'], sequence_info['piece_id'], 'single'
        else:
            # Multiple segments - stack them
            segments = []
            for seg_idx in indices:
                segments.append(self.data[seg_idx].unsqueeze(0))  # Add channel dim

            # Stack segments: (sequence_length, 1, 128, T)
            sequence = torch.stack(segments, dim=0)

            return sequence, sequence_info['label'], sequence_info['piece_id'], 'sequence'

# Create sequence-aware dataset
print("🔗 Creating sequence-aware dataset...")
sequence_dataset = SequenceAwareDataset(
    tracked_data,
    tracked_labels,
    tracked_piece_ids,
    sequence_length=2,  # Start with pairs
    include_singles=True
)

🔗 Creating sequence-aware dataset...
📊 SEQUENCE DATASET CREATED:
• Sequence length: 2
• Total pieces: 490
• Total sequences: 5706
• Include single segments: True


In [16]:
# =====================================================
# TRAINING SETUP WITH CLASS WEIGHTS
# =====================================================

import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from sklearn.model_selection import train_test_split

def create_balanced_training_setup(data, labels, piece_ids, class_weights, test_size=0.2, val_size=0.15):
    """
    Create training setup with class weights for imbalance handling
    """
    print("🎯 CREATING BALANCED TRAINING SETUP...")
    print(f"Using class weights: {class_weights}")

    # Split data at the piece level to avoid data leakage
    unique_pieces = np.unique(piece_ids)
    piece_labels = np.array([labels[piece_ids == pid][0] for pid in unique_pieces])

    # Split pieces into train/val/test
    train_pieces, test_pieces, _, _ = train_test_split(
        unique_pieces, piece_labels,
        test_size=test_size,
        stratify=piece_labels,
        random_state=42
    )

    train_pieces, val_pieces, _, _ = train_test_split(
        train_pieces, piece_labels[np.isin(unique_pieces, train_pieces)],
        test_size=val_size/(1-test_size),
        stratify=piece_labels[np.isin(unique_pieces, train_pieces)],
        random_state=42
    )

    # Create masks for train/val/test
    train_mask = np.isin(piece_ids, train_pieces)
    val_mask = np.isin(piece_ids, val_pieces)
    test_mask = np.isin(piece_ids, test_pieces)

    print(f"📊 DATA SPLIT:")
    print(f"Train pieces: {len(train_pieces)} | segments: {np.sum(train_mask)}")
    print(f"Val pieces:   {len(val_pieces)} | segments: {np.sum(val_mask)}")
    print(f"Test pieces:  {len(test_pieces)} | segments: {np.sum(test_mask)}")

    # Show class distribution per split
    for split_name, mask in [("Train", train_mask), ("Val", val_mask), ("Test", test_mask)]:
        split_labels = labels[mask]
        print(f"\n{split_name} distribution:")
        for i, composer in enumerate(TARGET_COMPOSERS):
            count = np.sum(split_labels == i)
            percentage = (count / len(split_labels)) * 100 if len(split_labels) > 0 else 0
            print(f"  {composer}: {count} ({percentage:.1f}%)")

    # Create datasets
    train_dataset = PianoRollDataset(data[train_mask], labels[train_mask])
    val_dataset = PianoRollDataset(data[val_mask], labels[val_mask])
    test_dataset = PianoRollDataset(data[test_mask], labels[test_mask])

    return train_dataset, val_dataset, test_dataset, train_mask, val_mask, test_mask

# Create the balanced training setup
train_dataset, val_dataset, test_dataset, train_mask, val_mask, test_mask = create_balanced_training_setup(
    tracked_data, tracked_labels, tracked_piece_ids, class_weights
)

# Create weighted loss function
weighted_criterion = nn.CrossEntropyLoss(weight=class_weights)
print(f"\n✅ Weighted loss function created with class weights")

# Create data loaders
batch_size = 16  # Smaller batch size for stability
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

print(f"\n📚 DATA LOADERS CREATED:")
print(f"• Batch size: {batch_size}")
print(f"• Train batches: {len(train_loader)}")
print(f"• Val batches: {len(val_loader)}")
print(f"• Test batches: {len(test_loader)}")

🎯 CREATING BALANCED TRAINING SETUP...
Using class weights: tensor([0.7927, 0.7847, 1.2697, 1.4781], device='cuda:0')
📊 DATA SPLIT:
Train pieces: 318 | segments: 1986
Val pieces:   74 | segments: 438
Test pieces:  98 | segments: 674

Train distribution:
  Bach: 606 (30.5%)
  Beethoven: 600 (30.2%)
  Chopin: 443 (22.3%)
  Mozart: 337 (17.0%)

Val distribution:
  Bach: 154 (35.2%)
  Beethoven: 160 (36.5%)
  Chopin: 53 (12.1%)
  Mozart: 71 (16.2%)

Test distribution:
  Bach: 217 (32.2%)
  Beethoven: 227 (33.7%)
  Chopin: 114 (16.9%)
  Mozart: 116 (17.2%)

✅ Weighted loss function created with class weights

📚 DATA LOADERS CREATED:
• Batch size: 16
• Train batches: 125
• Val batches: 28
• Test batches: 43


In [17]:
# =====================================================
# ADVANCED TRAINING LOOP FOR AGGRESSIVE MODEL
# =====================================================

import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, ReduceLROnPlateau
import time
from collections import defaultdict
import matplotlib.pyplot as plt

def train_aggressive_model(model, train_loader, val_loader, class_weights,
                          epochs=50, initial_lr=1e-3, save_path='aggressive_model.pth'):
    """
    Advanced training with:
    - Auxiliary loss for regularization
    - Cosine annealing with warm restarts
    - Mixed precision training
    - Advanced metrics tracking
    - Early stopping with patience
    """
    print("🚀 STARTING AGGRESSIVE MODEL TRAINING...")
    print(f"• Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"• Training samples: {len(train_loader.dataset)}")
    print(f"• Validation samples: {len(val_loader.dataset)}")
    print(f"• Epochs: {epochs}")
    print(f"• Initial learning rate: {initial_lr}")
    print(f"• Using mixed precision: True")
    print(f"• Class weights: {class_weights}")

    # Setup optimizers and schedulers
    optimizer = optim.AdamW(
        model.parameters(),
        lr=initial_lr,
        weight_decay=1e-4,
        betas=(0.9, 0.999),
        eps=1e-8
    )

    # Cosine annealing with warm restarts
    scheduler = CosineAnnealingWarmRestarts(
        optimizer,
        T_0=10,  # Restart every 10 epochs
        T_mult=2,  # Double the period after each restart
        eta_min=1e-6
    )

    # Mixed precision scaler
    scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None

    # Loss functions
    main_criterion = nn.CrossEntropyLoss(weight=class_weights)
    aux_criterion = nn.CrossEntropyLoss(weight=class_weights)

    # Training tracking
    history = {
        'train_loss': [], 'train_acc': [], 'train_main_loss': [], 'train_aux_loss': [],
        'val_loss': [], 'val_acc': [], 'val_main_loss': [], 'val_aux_loss': [],
        'lr': [], 'epoch_time': []
    }

    best_val_acc = 0.0
    patience_counter = 0
    patience = 15  # Early stopping patience

    print("\n🎯 Training Configuration:")
    print(f"• Optimizer: AdamW (lr={initial_lr}, weight_decay=1e-4)")
    print(f"• Scheduler: CosineAnnealingWarmRestarts (T_0=10, T_mult=2)")
    print(f"• Main loss weight: 0.7, Auxiliary loss weight: 0.3")
    print(f"• Early stopping patience: {patience}")
    print(f"• Mixed precision: {'CUDA' if scaler else 'Disabled'}")

    model.to(device)

    for epoch in range(epochs):
        epoch_start_time = time.time()

        # ==========================================
        # TRAINING PHASE
        # ==========================================
        model.train()
        train_metrics = defaultdict(float)
        train_correct = 0
        train_total = 0

        print(f"\n📈 Epoch {epoch+1}/{epochs}")
        print("-" * 50)

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()

            # Mixed precision forward pass
            if scaler:
                with torch.cuda.amp.autocast():
                    main_output, aux_output = model(data)
                    main_loss = main_criterion(main_output, target)
                    aux_loss = aux_criterion(aux_output, target)
                    # Combine losses with weights
                    total_loss = 0.7 * main_loss + 0.3 * aux_loss

                # Mixed precision backward pass
                scaler.scale(total_loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                main_output, aux_output = model(data)
                main_loss = main_criterion(main_output, target)
                aux_loss = aux_criterion(aux_output, target)
                total_loss = 0.7 * main_loss + 0.3 * aux_loss

                total_loss.backward()
                optimizer.step()

            # Track metrics
            train_metrics['total_loss'] += total_loss.item()
            train_metrics['main_loss'] += main_loss.item()
            train_metrics['aux_loss'] += aux_loss.item()

            # Accuracy from main output
            _, predicted = torch.max(main_output.data, 1)
            train_total += target.size(0)
            train_correct += (predicted == target).sum().item()

            # Progress update
            if batch_idx % 20 == 0:
                current_lr = optimizer.param_groups[0]['lr']
                print(f"  Batch {batch_idx:3d}/{len(train_loader)} | "
                      f"Loss: {total_loss.item():.4f} | "
                      f"Acc: {100.*train_correct/train_total:.2f}% | "
                      f"LR: {current_lr:.2e}")

        # ==========================================
        # VALIDATION PHASE
        # ==========================================
        model.eval()
        val_metrics = defaultdict(float)
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)

                if scaler:
                    with torch.cuda.amp.autocast():
                        main_output, aux_output = model(data)
                        main_loss = main_criterion(main_output, target)
                        aux_loss = aux_criterion(aux_output, target)
                        total_loss = 0.7 * main_loss + 0.3 * aux_loss
                else:
                    main_output, aux_output = model(data)
                    main_loss = main_criterion(main_output, target)
                    aux_loss = aux_criterion(aux_output, target)
                    total_loss = 0.7 * main_loss + 0.3 * aux_loss

                val_metrics['total_loss'] += total_loss.item()
                val_metrics['main_loss'] += main_loss.item()
                val_metrics['aux_loss'] += aux_loss.item()

                _, predicted = torch.max(main_output.data, 1)
                val_total += target.size(0)
                val_correct += (predicted == target).sum().item()

        # Calculate epoch metrics
        train_loss = train_metrics['total_loss'] / len(train_loader)
        train_acc = 100. * train_correct / train_total
        val_loss = val_metrics['total_loss'] / len(val_loader)
        val_acc = 100. * val_correct / val_total

        # Update learning rate
        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']

        # Record history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['train_main_loss'].append(train_metrics['main_loss'] / len(train_loader))
        history['train_aux_loss'].append(train_metrics['aux_loss'] / len(train_loader))

        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['val_main_loss'].append(val_metrics['main_loss'] / len(val_loader))
        history['val_aux_loss'].append(val_metrics['aux_loss'] / len(val_loader))

        history['lr'].append(current_lr)

        epoch_time = time.time() - epoch_start_time
        history['epoch_time'].append(epoch_time)

        # Print epoch summary
        print(f"\n📊 Epoch {epoch+1} Summary:")
        print(f"  Train: Loss={train_loss:.4f}, Acc={train_acc:.2f}%")
        print(f"  Val:   Loss={val_loss:.4f}, Acc={val_acc:.2f}%")
        print(f"  LR: {current_lr:.2e}, Time: {epoch_time:.1f}s")

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_val_acc': best_val_acc,
                'history': history
            }, save_path)
            print(f"  💾 New best model saved! Val Acc: {val_acc:.2f}%")
        else:
            patience_counter += 1
            print(f"  ⏳ Patience: {patience_counter}/{patience}")

        # Early stopping
        if patience_counter >= patience:
            print(f"\n🛑 Early stopping triggered! Best Val Acc: {best_val_acc:.2f}%")
            break

        # Memory cleanup
        if device.type == 'cuda':
            torch.cuda.empty_cache()

    print(f"\n✅ Training completed!")
    print(f"Best validation accuracy: {best_val_acc:.2f}%")
    print(f"Model saved to: {save_path}")

    return history

# Create save directory
os.makedirs('saved_models', exist_ok=True)

# Start training the aggressive model
print("🚀 Starting training of AGGRESSIVE CNN-LSTM-Transformer...")
history = train_aggressive_model(
    aggressive_model,
    train_loader,
    val_loader,
    class_weights,
    epochs=50,
    initial_lr=1e-3,
    save_path='saved_models/aggressive_cnn_lstm_transformer.pth'
)

🚀 Starting training of AGGRESSIVE CNN-LSTM-Transformer...
🚀 STARTING AGGRESSIVE MODEL TRAINING...
• Model parameters: 185,688,776
• Training samples: 1986
• Validation samples: 438
• Epochs: 50
• Initial learning rate: 0.001
• Using mixed precision: True
• Class weights: tensor([0.7927, 0.7847, 1.2697, 1.4781], device='cuda:0')


  scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None
  with torch.cuda.amp.autocast():



🎯 Training Configuration:
• Optimizer: AdamW (lr=0.001, weight_decay=1e-4)
• Scheduler: CosineAnnealingWarmRestarts (T_0=10, T_mult=2)
• Main loss weight: 0.7, Auxiliary loss weight: 0.3
• Early stopping patience: 15
• Mixed precision: CUDA

📈 Epoch 1/50
--------------------------------------------------
  Batch   0/125 | Loss: 1.3886 | Acc: 18.75% | LR: 1.00e-03
  Batch  20/125 | Loss: 1.3827 | Acc: 27.08% | LR: 1.00e-03
  Batch  40/125 | Loss: 1.5159 | Acc: 26.83% | LR: 1.00e-03
  Batch  60/125 | Loss: 1.3915 | Acc: 24.90% | LR: 1.00e-03
  Batch  80/125 | Loss: 1.4364 | Acc: 24.46% | LR: 1.00e-03
  Batch 100/125 | Loss: 1.4357 | Acc: 23.76% | LR: 1.00e-03
  Batch 120/125 | Loss: 1.3804 | Acc: 24.23% | LR: 1.00e-03


  with torch.cuda.amp.autocast():



📊 Epoch 1 Summary:
  Train: Loss=1.3986, Acc=24.42%
  Val:   Loss=1.4438, Acc=12.10%
  LR: 9.76e-04, Time: 39.0s
  💾 New best model saved! Val Acc: 12.10%

📈 Epoch 2/50
--------------------------------------------------
  Batch   0/125 | Loss: 1.4461 | Acc: 12.50% | LR: 9.76e-04
  Batch  20/125 | Loss: 1.3940 | Acc: 17.86% | LR: 9.76e-04
  Batch  40/125 | Loss: 1.4059 | Acc: 21.34% | LR: 9.76e-04
  Batch  60/125 | Loss: 1.4010 | Acc: 23.36% | LR: 9.76e-04
  Batch  80/125 | Loss: 1.4133 | Acc: 24.77% | LR: 9.76e-04
  Batch 100/125 | Loss: 1.4068 | Acc: 23.45% | LR: 9.76e-04
  Batch 120/125 | Loss: 1.3862 | Acc: 24.43% | LR: 9.76e-04

📊 Epoch 2 Summary:
  Train: Loss=1.3948, Acc=24.42%
  Val:   Loss=1.4042, Acc=12.10%
  LR: 9.05e-04, Time: 38.1s
  ⏳ Patience: 1/15

📈 Epoch 3/50
--------------------------------------------------
  Batch   0/125 | Loss: 1.3915 | Acc: 18.75% | LR: 9.05e-04
  Batch  20/125 | Loss: 1.3860 | Acc: 21.13% | LR: 9.05e-04
  Batch  40/125 | Loss: 1.4380 | Acc: 19.

In [18]:
# =====================================================
# TRAINING STABILITY FIXES FOR AGGRESSIVE MODEL
# =====================================================

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import time
from collections import defaultdict

def train_stable_aggressive_model(model, train_loader, val_loader, class_weights,
                                 epochs=50, initial_lr=5e-4, save_path='stable_aggressive_model.pth'):
    """
    STABLE training with fixes for NaN loss:
    - Gradient clipping
    - Lower learning rate
    - Gradient accumulation
    - Better initialization
    - Loss scaling protection
    - Model parameter monitoring
    """
    print("🔧 STARTING STABLE AGGRESSIVE MODEL TRAINING...")
    print("🛡️ STABILITY FIXES APPLIED:")
    print("• Gradient clipping (max_norm=1.0)")
    print("• Lower learning rate (5e-4 → 1e-4)")
    print("• Gradient accumulation (effective batch size x2)")
    print("• Weight initialization check")
    print("• NaN detection and recovery")
    print("• Mixed precision with loss scaling")

    # ==========================================
    # MODEL STABILITY CHECKS
    # ==========================================

    # Check for NaN in initial weights
    def check_model_weights(model, name=""):
        nan_count = 0
        inf_count = 0
        total_params = 0

        for name, param in model.named_parameters():
            if torch.isnan(param).any():
                print(f"⚠️ NaN detected in {name}")
                nan_count += 1
            if torch.isinf(param).any():
                print(f"⚠️ Inf detected in {name}")
                inf_count += 1
            total_params += param.numel()

        print(f"Model check: {nan_count} NaN params, {inf_count} Inf params, {total_params:,} total")
        return nan_count == 0 and inf_count == 0

    # Initialize model weights properly
    def init_weights(m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight, gain=0.1)  # Smaller gain
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu', a=0.1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LSTM):
            for name, param in m.named_parameters():
                if 'weight' in name:
                    nn.init.orthogonal_(param, gain=0.1)
                elif 'bias' in name:
                    nn.init.constant_(param, 0)

    print("🔄 Reinitializing model weights with smaller scale...")
    model.apply(init_weights)

    if not check_model_weights(model, "After reinitialization"):
        print("❌ Model has NaN/Inf after initialization!")
        return None

    print("✅ Model weights are stable")

    # ==========================================
    # OPTIMIZER & SCHEDULER SETUP
    # ==========================================

    # Much more conservative optimizer settings
    optimizer = optim.AdamW(
        model.parameters(),
        lr=initial_lr,  # Reduced from 1e-3 to 5e-4
        weight_decay=1e-5,  # Reduced weight decay
        betas=(0.9, 0.98),  # More stable betas
        eps=1e-6  # Larger epsilon
    )

    # More conservative scheduler
    scheduler = CosineAnnealingWarmRestarts(
        optimizer,
        T_0=20,  # Longer restart period
        T_mult=1,  # No multiplication
        eta_min=1e-7
    )

    # Mixed precision with careful scaling
    scaler = torch.cuda.amp.GradScaler(
        init_scale=2**10,  # Smaller initial scale
        growth_factor=1.1,  # Slower growth
        backoff_factor=0.8,  # More aggressive backoff
        growth_interval=100  # Less frequent growth
    ) if device.type == 'cuda' else None

    # Loss functions with label smoothing for stability
    main_criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)
    aux_criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)

    # Training tracking
    history = {
        'train_loss': [], 'train_acc': [], 'train_main_loss': [], 'train_aux_loss': [],
        'val_loss': [], 'val_acc': [], 'val_main_loss': [], 'val_aux_loss': [],
        'lr': [], 'epoch_time': [], 'grad_norm': []
    }

    best_val_acc = 0.0
    patience_counter = 0
    patience = 20  # Increased patience
    accumulation_steps = 2  # Gradient accumulation

    print(f"\n🎯 STABLE Training Configuration:")
    print(f"• Learning rate: {initial_lr} (reduced)")
    print(f"• Weight decay: 1e-5 (reduced)")
    print(f"• Gradient clipping: max_norm=1.0")
    print(f"• Gradient accumulation: {accumulation_steps} steps")
    print(f"• Label smoothing: 0.1")
    print(f"• Patience: {patience}")

    model.to(device)

    for epoch in range(epochs):
        epoch_start_time = time.time()

        # ==========================================
        # TRAINING PHASE WITH STABILITY
        # ==========================================
        model.train()
        train_metrics = defaultdict(float)
        train_correct = 0
        train_total = 0
        grad_norms = []

        print(f"\n📈 Epoch {epoch+1}/{epochs}")
        print("-" * 50)

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            # Mixed precision forward pass
            if scaler:
                with torch.cuda.amp.autocast():
                    main_output, aux_output = model(data)
                    main_loss = main_criterion(main_output, target)
                    aux_loss = aux_criterion(aux_output, target)
                    total_loss = 0.7 * main_loss + 0.3 * aux_loss
                    # Scale loss for gradient accumulation
                    total_loss = total_loss / accumulation_steps

                # Backward pass
                scaler.scale(total_loss).backward()
            else:
                main_output, aux_output = model(data)
                main_loss = main_criterion(main_output, target)
                aux_loss = aux_criterion(aux_output, target)
                total_loss = 0.7 * main_loss + 0.3 * aux_loss
                total_loss = total_loss / accumulation_steps
                total_loss.backward()

            # Gradient accumulation
            if (batch_idx + 1) % accumulation_steps == 0:
                if scaler:
                    # Gradient clipping before optimizer step
                    scaler.unscale_(optimizer)
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                    grad_norms.append(grad_norm.item())

                    # Check for NaN gradients
                    if torch.isnan(grad_norm) or torch.isinf(grad_norm):
                        print(f"⚠️ NaN/Inf gradient detected at batch {batch_idx}, skipping...")
                        scaler.update()
                        optimizer.zero_grad()
                        continue

                    scaler.step(optimizer)
                    scaler.update()
                else:
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                    grad_norms.append(grad_norm.item())

                    if torch.isnan(grad_norm) or torch.isinf(grad_norm):
                        print(f"⚠️ NaN/Inf gradient detected at batch {batch_idx}, skipping...")
                        optimizer.zero_grad()
                        continue

                    optimizer.step()

                optimizer.zero_grad()

            # Track metrics (scale back the loss)
            actual_loss = total_loss.item() * accumulation_steps
            if not (torch.isnan(torch.tensor(actual_loss)) or torch.isinf(torch.tensor(actual_loss))):
                train_metrics['total_loss'] += actual_loss
                train_metrics['main_loss'] += main_loss.item()
                train_metrics['aux_loss'] += aux_loss.item()

                # Accuracy from main output
                _, predicted = torch.max(main_output.data, 1)
                train_total += target.size(0)
                train_correct += (predicted == target).sum().item()

            # Progress update
            if batch_idx % 20 == 0:
                current_lr = optimizer.param_groups[0]['lr']
                avg_grad_norm = np.mean(grad_norms[-10:]) if grad_norms else 0.0
                print(f"  Batch {batch_idx:3d}/{len(train_loader)} | "
                      f"Loss: {actual_loss:.4f} | "
                      f"Acc: {100.*train_correct/train_total:.2f}% | "
                      f"LR: {current_lr:.2e} | "
                      f"GradNorm: {avg_grad_norm:.3f}")

        # ==========================================
        # VALIDATION PHASE
        # ==========================================
        model.eval()
        val_metrics = defaultdict(float)
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)

                if scaler:
                    with torch.cuda.amp.autocast():
                        main_output, aux_output = model(data)
                        main_loss = main_criterion(main_output, target)
                        aux_loss = aux_criterion(aux_output, target)
                        total_loss = 0.7 * main_loss + 0.3 * aux_loss
                else:
                    main_output, aux_output = model(data)
                    main_loss = main_criterion(main_output, target)
                    aux_loss = aux_criterion(aux_output, target)
                    total_loss = 0.7 * main_loss + 0.3 * aux_loss

                # Only add if not NaN
                if not torch.isnan(total_loss):
                    val_metrics['total_loss'] += total_loss.item()
                    val_metrics['main_loss'] += main_loss.item()
                    val_metrics['aux_loss'] += aux_loss.item()

                    _, predicted = torch.max(main_output.data, 1)
                    val_total += target.size(0)
                    val_correct += (predicted == target).sum().item()

        # Calculate epoch metrics
        if len(train_loader) > 0 and train_total > 0:
            train_loss = train_metrics['total_loss'] / len(train_loader)
            train_acc = 100. * train_correct / train_total
        else:
            train_loss = float('inf')
            train_acc = 0.0

        if len(val_loader) > 0 and val_total > 0:
            val_loss = val_metrics['total_loss'] / len(val_loader)
            val_acc = 100. * val_correct / val_total
        else:
            val_loss = float('inf')
            val_acc = 0.0

        # Update learning rate
        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']

        # Record history
        if not (torch.isnan(torch.tensor(train_loss)) or torch.isnan(torch.tensor(val_loss))):
            history['train_loss'].append(train_loss)
            history['train_acc'].append(train_acc)
            history['train_main_loss'].append(train_metrics['main_loss'] / len(train_loader))
            history['train_aux_loss'].append(train_metrics['aux_loss'] / len(train_loader))

            history['val_loss'].append(val_loss)
            history['val_acc'].append(val_acc)
            history['val_main_loss'].append(val_metrics['main_loss'] / len(val_loader))
            history['val_aux_loss'].append(val_metrics['aux_loss'] / len(val_loader))

            history['lr'].append(current_lr)
            history['grad_norm'].append(np.mean(grad_norms) if grad_norms else 0.0)

            epoch_time = time.time() - epoch_start_time
            history['epoch_time'].append(epoch_time)

            # Print epoch summary
            print(f"\n📊 Epoch {epoch+1} Summary:")
            print(f"  Train: Loss={train_loss:.4f}, Acc={train_acc:.2f}%")
            print(f"  Val:   Loss={val_loss:.4f}, Acc={val_acc:.2f}%")
            print(f"  LR: {current_lr:.2e}, Time: {epoch_time:.1f}s")
            print(f"  Avg Grad Norm: {np.mean(grad_norms):.3f}")

            # Save best model
            if val_acc > best_val_acc and not torch.isnan(torch.tensor(val_acc)):
                best_val_acc = val_acc
                patience_counter = 0
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'best_val_acc': best_val_acc,
                    'history': history
                }, save_path)
                print(f"  💾 New best model saved! Val Acc: {val_acc:.2f}%")
            else:
                patience_counter += 1
                print(f"  ⏳ Patience: {patience_counter}/{patience}")

            # Early stopping
            if patience_counter >= patience:
                print(f"\n🛑 Early stopping triggered! Best Val Acc: {best_val_acc:.2f}%")
                break
        else:
            print(f"\n⚠️ NaN loss detected in epoch {epoch+1}, continuing...")
            patience_counter += 1

        # Memory cleanup
        if device.type == 'cuda':
            torch.cuda.empty_cache()

    print(f"\n✅ Stable training completed!")
    print(f"Best validation accuracy: {best_val_acc:.2f}%")
    print(f"Model saved to: {save_path}")

    return history

# Create a fresh model instance with proper initialization
print("🔄 Creating fresh aggressive model with stability fixes...")
stable_aggressive_model = AggressiveCNN_LSTM_Transformer(
    num_classes=4,
    lstm_hidden=384,        # Slightly reduced from 512
    transformer_dim=768,    # Slightly reduced from 1024
    num_heads=12,          # Reduced from 16
    num_layers=6           # Reduced from 8
).to(device)

print("🚀 Starting STABLE training of aggressive model...")
stable_history = train_stable_aggressive_model(
    stable_aggressive_model,
    train_loader,
    val_loader,
    class_weights,
    epochs=50,
    initial_lr=5e-4,  # Much lower learning rate
    save_path='saved_models/stable_aggressive_cnn_lstm_transformer.pth'
)

🔄 Creating fresh aggressive model with stability fixes...
🚀 Building AGGRESSIVE CNN-LSTM-Transformer for A100 40GB...
• Deep CNN feature extraction (6 blocks)
• Large LSTM temporal modeling (hidden: 384)
• Deep Transformer self-attention (dim: 768, heads: 12, layers: 6)
• Multi-scale feature fusion
• Advanced attention mechanisms




✅ AGGRESSIVE CNN-LSTM-Transformer architecture built!
🚀 Starting STABLE training of aggressive model...
🔧 STARTING STABLE AGGRESSIVE MODEL TRAINING...
🛡️ STABILITY FIXES APPLIED:
• Gradient clipping (max_norm=1.0)
• Lower learning rate (5e-4 → 1e-4)
• Gradient accumulation (effective batch size x2)
• Weight initialization check
• NaN detection and recovery
• Mixed precision with loss scaling
🔄 Reinitializing model weights with smaller scale...
Model check: 0 NaN params, 0 Inf params, 106,655,432 total
✅ Model weights are stable

🎯 STABLE Training Configuration:
• Learning rate: 0.0005 (reduced)
• Weight decay: 1e-5 (reduced)
• Gradient clipping: max_norm=1.0
• Gradient accumulation: 2 steps
• Label smoothing: 0.1
• Patience: 20

📈 Epoch 1/50
--------------------------------------------------


  scaler = torch.cuda.amp.GradScaler(
  with torch.cuda.amp.autocast():


  Batch   0/125 | Loss: 1.3962 | Acc: 25.00% | LR: 5.00e-04 | GradNorm: 0.000
  Batch  20/125 | Loss: 1.3916 | Acc: 23.81% | LR: 5.00e-04 | GradNorm: 0.109
  Batch  40/125 | Loss: 1.3934 | Acc: 23.32% | LR: 5.00e-04 | GradNorm: 0.084
  Batch  60/125 | Loss: 1.4074 | Acc: 21.52% | LR: 5.00e-04 | GradNorm: 0.094
  Batch  80/125 | Loss: 1.4174 | Acc: 21.76% | LR: 5.00e-04 | GradNorm: 0.119
  Batch 100/125 | Loss: 1.4079 | Acc: 22.34% | LR: 5.00e-04 | GradNorm: 0.120
  Batch 120/125 | Loss: 1.3913 | Acc: 21.85% | LR: 5.00e-04 | GradNorm: 0.135


  with torch.cuda.amp.autocast():



📊 Epoch 1 Summary:
  Train: Loss=1.3963, Acc=22.16%
  Val:   Loss=1.4146, Acc=12.10%
  LR: 4.97e-04, Time: 39.1s
  Avg Grad Norm: 0.111
  💾 New best model saved! Val Acc: 12.10%

📈 Epoch 2/50
--------------------------------------------------
  Batch   0/125 | Loss: 1.3910 | Acc: 12.50% | LR: 4.97e-04 | GradNorm: 0.000
  Batch  20/125 | Loss: 1.3970 | Acc: 20.83% | LR: 4.97e-04 | GradNorm: 0.135
  Batch  40/125 | Loss: 1.3841 | Acc: 21.04% | LR: 4.97e-04 | GradNorm: 0.096
  Batch  60/125 | Loss: 1.4186 | Acc: 20.39% | LR: 4.97e-04 | GradNorm: 0.100
  Batch  80/125 | Loss: 1.3956 | Acc: 21.30% | LR: 4.97e-04 | GradNorm: 0.109
  Batch 100/125 | Loss: 1.3442 | Acc: 21.97% | LR: 4.97e-04 | GradNorm: 0.198
  Batch 120/125 | Loss: 1.4253 | Acc: 22.00% | LR: 4.97e-04 | GradNorm: 0.197

📊 Epoch 2 Summary:
  Train: Loss=1.3956, Acc=22.10%
  Val:   Loss=1.4278, Acc=12.10%
  LR: 4.88e-04, Time: 39.1s
  Avg Grad Norm: 0.140
  ⏳ Patience: 1/20

📈 Epoch 3/50
----------------------------------------

KeyboardInterrupt: 

In [21]:
# =====================================================
# SIMPLE WORKING MODEL WITH LEAKY RELU
# =====================================================

class SimpleCNN_LSTM_LeakyReLU(nn.Module):
    def __init__(self, num_classes=4, lstm_hidden=128):
        super(SimpleCNN_LSTM_LeakyReLU, self).__init__()

        print("🎯 Building SIMPLE CNN-LSTM with LeakyReLU...")

        # Simple CNN - 3 blocks only (like your working model)
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2),  # LeakyReLU as suggested
            nn.Conv2d(32, 32, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2),
            nn.Dropout2d(0.1),
            nn.MaxPool2d(kernel_size=(2, 2))
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 64, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Dropout2d(0.15),
            nn.MaxPool2d(kernel_size=(2, 2))
        )

        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 128, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Dropout2d(0.2),
            nn.MaxPool2d(kernel_size=(2, 2))
        )

        # Calculate feature size: same as your original model
        self.feature_size = 128 * 16  # 128 channels * 16 remaining pitch bins

        # Simple LSTM
        self.lstm = nn.LSTM(
            input_size=self.feature_size,
            hidden_size=lstm_hidden,
            num_layers=2,
            batch_first=True,
            dropout=0.2,
            bidirectional=True
        )

        # Simple attention
        self.attention = nn.MultiheadAttention(
            embed_dim=lstm_hidden * 2,
            num_heads=4,  # Fewer heads
            dropout=0.2,
            batch_first=True
        )

        # Simple classifier with LeakyReLU
        self.classifier = nn.Sequential(
            nn.Linear(lstm_hidden * 2, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes)
        )

        print("✅ Simple LeakyReLU architecture built!")

    def forward(self, x):
        batch_size = x.size(0)

        # CNN feature extraction
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)

        # Reshape for LSTM (same as original)
        x = x.permute(0, 3, 1, 2)
        x = x.contiguous().view(batch_size, x.size(1), -1)

        # LSTM processing
        lstm_out, _ = self.lstm(x)

        # Attention
        attn_out, _ = self.attention(lstm_out, lstm_out, lstm_out)

        # Global average pooling
        pooled = torch.mean(attn_out, dim=1)

        # Classification
        output = self.classifier(pooled)

        return output

# Create simple model
simple_model = SimpleCNN_LSTM_LeakyReLU(num_classes=4, lstm_hidden=128).to(device)

# Count parameters
total_params = sum(p.numel() for p in simple_model.parameters())
print(f"📊 Simple Model: {total_params:,} parameters (vs 106M+ in aggressive model)")

# Test forward pass
test_input = torch.randn(4, 1, 128, 4500).to(device)
with torch.no_grad():
    output = simple_model(test_input)
    print(f"Output shape: {output.shape}")
    print(f"✅ Simple model forward pass successful!")

🎯 Building SIMPLE CNN-LSTM with LeakyReLU...
✅ Simple LeakyReLU architecture built!
📊 Simple Model: 3,275,236 parameters (vs 106M+ in aggressive model)
Output shape: torch.Size([4, 4])
✅ Simple model forward pass successful!


In [23]:
def train_simple_model_working():
    """Train the simple model with settings that should work"""

    print("🚀 TRAINING SIMPLE MODEL WITH PROVEN SETTINGS...")

    # Use higher learning rate and simple loss
    optimizer = optim.Adam(simple_model.parameters(), lr=1e-2)  # Higher LR
    criterion = nn.CrossEntropyLoss()  # No class weights initially

    # Train for just 10 epochs to test
    epochs = 10

    for epoch in range(epochs):
        simple_model.train()
        total_loss = 0
        correct = 0
        total = 0

        print(f"\nEpoch {epoch+1}/{epochs}")

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            outputs = simple_model(data)
            loss = criterion(outputs, target)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

            if batch_idx % 20 == 0:
                print(f"  Batch {batch_idx:3d}/{len(train_loader)} | "
                      f"Loss: {loss.item():.4f} | "
                      f"Acc: {100.*correct/total:.2f}%")

        # Validation
        simple_model.eval()
        val_loss = 0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                outputs = simple_model(data)
                loss = criterion(outputs, target)

                val_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                val_total += target.size(0)
                val_correct += (predicted == target).sum().item()

        train_acc = 100. * correct / total
        val_acc = 100. * val_correct / val_total

        print(f"📊 Epoch {epoch+1} Summary:")
        print(f"  Train: Loss={total_loss/len(train_loader):.4f}, Acc={train_acc:.2f}%")
        print(f"  Val:   Loss={val_loss/len(val_loader):.4f}, Acc={val_acc:.2f}%")

        # If validation accuracy > 25% (random chance), model is learning!
        if val_acc > 25:
            print(f"  ✅ Model is learning! (Val Acc > random: {val_acc:.2f}%)")
        else:
            print(f"  ⚠️ Still around random chance: {val_acc:.2f}%")

# Run the simple training
train_simple_model_working()

🚀 TRAINING SIMPLE MODEL WITH PROVEN SETTINGS...

Epoch 1/10
  Batch   0/125 | Loss: 1.4000 | Acc: 6.25%
  Batch  20/125 | Loss: 1.4118 | Acc: 25.00%
  Batch  40/125 | Loss: 1.3426 | Acc: 25.46%
  Batch  60/125 | Loss: 1.4429 | Acc: 30.43%
  Batch  80/125 | Loss: 1.3237 | Acc: 33.56%
  Batch 100/125 | Loss: 1.3595 | Acc: 32.43%
  Batch 120/125 | Loss: 1.4348 | Acc: 31.87%
📊 Epoch 1 Summary:
  Train: Loss=1.4955, Acc=31.52%
  Val:   Loss=1.3384, Acc=35.16%
  ✅ Model is learning! (Val Acc > random: 35.16%)

Epoch 2/10
  Batch   0/125 | Loss: 1.4588 | Acc: 18.75%
  Batch  20/125 | Loss: 1.3310 | Acc: 32.74%
  Batch  40/125 | Loss: 1.4169 | Acc: 32.16%
  Batch  60/125 | Loss: 1.4940 | Acc: 32.58%
  Batch  80/125 | Loss: 1.3641 | Acc: 31.64%
  Batch 100/125 | Loss: 1.4821 | Acc: 32.67%
  Batch 120/125 | Loss: 1.2441 | Acc: 35.07%
📊 Epoch 2 Summary:
  Train: Loss=1.3399, Acc=34.94%
  Val:   Loss=1.2627, Acc=46.12%
  ✅ Model is learning! (Val Acc > random: 46.12%)

Epoch 3/10
  Batch   0/125 |

In [24]:
def train_improved_simple_model():
    """Train the simple model with improvements for better performance"""

    print("🚀 TRAINING IMPROVED SIMPLE MODEL...")

    # Use the working settings but with improvements
    optimizer = optim.AdamW(simple_model.parameters(), lr=5e-3, weight_decay=1e-4)  # Slightly lower LR + weight decay

    # Add class weights now that we know it can learn
    criterion = nn.CrossEntropyLoss(weight=class_weights)  # Now use class weights

    # Learning rate scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=3, verbose=True
    )

    # Train for more epochs with early stopping
    epochs = 25
    best_val_acc = 0.0
    patience_counter = 0
    patience = 8

    for epoch in range(epochs):
        simple_model.train()
        total_loss = 0
        correct = 0
        total = 0

        print(f"\nEpoch {epoch+1}/{epochs}")

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            outputs = simple_model(data)
            loss = criterion(outputs, target)
            loss.backward()

            # Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(simple_model.parameters(), max_norm=1.0)

            optimizer.step()

            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

            if batch_idx % 25 == 0:
                print(f"  Batch {batch_idx:3d}/{len(train_loader)} | "
                      f"Loss: {loss.item():.4f} | "
                      f"Acc: {100.*correct/total:.2f}% | "
                      f"LR: {optimizer.param_groups[0]['lr']:.1e}")

        # Validation
        simple_model.eval()
        val_loss = 0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                outputs = simple_model(data)
                loss = criterion(outputs, target)

                val_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                val_total += target.size(0)
                val_correct += (predicted == target).sum().item()

        train_acc = 100. * correct / total
        val_acc = 100. * val_correct / val_total

        print(f"📊 Epoch {epoch+1} Summary:")
        print(f"  Train: Loss={total_loss/len(train_loader):.4f}, Acc={train_acc:.2f}%")
        print(f"  Val:   Loss={val_loss/len(val_loader):.4f}, Acc={val_acc:.2f}%")

        # Update learning rate
        scheduler.step(val_acc)

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            torch.save({
                'model_state_dict': simple_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'epoch': epoch
            }, 'saved_models/simple_leaky_relu_model.pth')
            print(f"  💾 New best model saved! Val Acc: {val_acc:.2f}%")
        else:
            patience_counter += 1
            print(f"  ⏳ Patience: {patience_counter}/{patience}")

        # Early stopping
        if patience_counter >= patience:
            print(f"\n🛑 Early stopping! Best Val Acc: {best_val_acc:.2f}%")
            break

    print(f"\n✅ Training completed! Best validation accuracy: {best_val_acc:.2f}%")
    return best_val_acc

# Create save directory
os.makedirs('saved_models', exist_ok=True)

# Train the improved model
best_acc = train_improved_simple_model()

🚀 TRAINING IMPROVED SIMPLE MODEL...

Epoch 1/25




  Batch   0/125 | Loss: 1.1525 | Acc: 50.00% | LR: 5.0e-03
  Batch  25/125 | Loss: 0.8272 | Acc: 51.68% | LR: 5.0e-03
  Batch  50/125 | Loss: 1.0390 | Acc: 53.92% | LR: 5.0e-03
  Batch  75/125 | Loss: 0.9505 | Acc: 56.66% | LR: 5.0e-03
  Batch 100/125 | Loss: 1.0458 | Acc: 57.18% | LR: 5.0e-03
📊 Epoch 1 Summary:
  Train: Loss=0.9813, Acc=58.61%
  Val:   Loss=1.1967, Acc=59.13%
  💾 New best model saved! Val Acc: 59.13%

Epoch 2/25
  Batch   0/125 | Loss: 0.7123 | Acc: 75.00% | LR: 5.0e-03
  Batch  25/125 | Loss: 0.8929 | Acc: 63.70% | LR: 5.0e-03
  Batch  50/125 | Loss: 1.1012 | Acc: 63.60% | LR: 5.0e-03
  Batch  75/125 | Loss: 0.7676 | Acc: 66.12% | LR: 5.0e-03
  Batch 100/125 | Loss: 0.9061 | Acc: 64.48% | LR: 5.0e-03
📊 Epoch 2 Summary:
  Train: Loss=0.9057, Acc=63.60%
  Val:   Loss=1.0638, Acc=61.42%
  💾 New best model saved! Val Acc: 61.42%

Epoch 3/25
  Batch   0/125 | Loss: 1.2152 | Acc: 56.25% | LR: 5.0e-03
  Batch  25/125 | Loss: 0.6727 | Acc: 66.59% | LR: 5.0e-03
  Batch  50/12

In [25]:
# =====================================================
# MEMORY CLEANUP AND OPTIMIZATION
# =====================================================

import gc
import torch

def cleanup_memory():
    """Comprehensive memory cleanup"""
    print("🧹 CLEANING UP MEMORY...")

    # List of large models/variables to delete
    models_to_cleanup = [
        'aggressive_model',
        'stable_aggressive_model',
        'model',           # Original CNN_LSTM model
        'rhythm_model',    # Rhythm augmented model
    ]

    # Delete models from global scope
    for model_name in models_to_cleanup:
        if model_name in globals():
            print(f"  Deleting {model_name}...")
            del globals()[model_name]

    # Also clean up any large data variables we don't need
    data_to_cleanup = [
        'improved_data',
        'improved_labels',
        'full_data',
        'full_labels',
        'orig_probs',
        'rhythm_probs',
        'stable_history',
        'history'
    ]

    for data_name in data_to_cleanup:
        if data_name in globals():
            print(f"  Deleting {data_name}...")
            del globals()[data_name]

    # Force garbage collection
    gc.collect()

    # Clear GPU cache if using CUDA
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print(f"  GPU memory cleared")

    # Clear MPS cache if using MPS (Mac)
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        torch.mps.empty_cache()
        print(f"  MPS memory cleared")

    print("✅ Memory cleanup completed!")

    # Show current memory usage if possible
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**3
        cached = torch.cuda.memory_reserved() / 1024**3
        print(f"📊 GPU Memory: {allocated:.2f}GB allocated, {cached:.2f}GB cached")

# Run cleanup
cleanup_memory()

# Recreate only the simple model (much smaller)
print("\n🎯 Recreating only the simple working model...")
simple_model = SimpleCNN_LSTM_LeakyReLU(num_classes=4, lstm_hidden=128).to(device)

# Count parameters
total_params = sum(p.numel() for p in simple_model.parameters())
print(f"📊 Simple Model: {total_params:,} parameters (~{total_params * 4 / 1024**2:.1f}MB)")

print("\n✅ Memory optimized! Only keeping the working simple model.")

🧹 CLEANING UP MEMORY...
  Deleting aggressive_model...
  Deleting stable_aggressive_model...
  Deleting improved_data...
  Deleting improved_labels...
  Deleting full_data...
  Deleting full_labels...
  Deleting history...
  GPU memory cleared
✅ Memory cleanup completed!
📊 GPU Memory: 2.32GB allocated, 6.91GB cached

🎯 Recreating only the simple working model...
🎯 Building SIMPLE CNN-LSTM with LeakyReLU...
✅ Simple LeakyReLU architecture built!
📊 Simple Model: 3,275,236 parameters (~12.5MB)

✅ Memory optimized! Only keeping the working simple model.


In [26]:
# =====================================================
# ADVANCED DATA AUGMENTATION FOR MUSIC
# =====================================================

import torch
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset

class MusicDataAugmentation:
    """Advanced music-specific data augmentation"""

    def __init__(self, pitch_shift_range=6, time_stretch_range=0.2, noise_level=0.01):
        self.pitch_shift_range = pitch_shift_range  # ±6 semitones
        self.time_stretch_range = time_stretch_range  # ±20% time stretch
        self.noise_level = noise_level

    def pitch_shift(self, piano_roll, shift_amount):
        """Shift piano roll by semitones"""
        if shift_amount == 0:
            return piano_roll

        shifted = torch.zeros_like(piano_roll)
        if shift_amount > 0:
            # Shift up
            shifted[shift_amount:, :] = piano_roll[:-shift_amount, :]
        else:
            # Shift down
            shifted[:shift_amount, :] = piano_roll[-shift_amount:, :]

        return shifted

    def time_stretch(self, piano_roll, stretch_factor):
        """Time stretch using interpolation"""
        if abs(stretch_factor - 1.0) < 0.01:
            return piano_roll

        # Add batch and channel dims for interpolation
        x = piano_roll.unsqueeze(0).unsqueeze(0)  # (1, 1, 128, T)

        # Calculate new width
        new_width = int(x.size(-1) * stretch_factor)

        # Interpolate along time dimension
        stretched = F.interpolate(x, size=(128, new_width), mode='bilinear', align_corners=False)

        # Remove extra dims and crop/pad to original size
        stretched = stretched.squeeze(0).squeeze(0)  # (128, new_width)

        original_width = piano_roll.size(-1)
        if new_width > original_width:
            # Crop from center
            start_idx = (new_width - original_width) // 2
            stretched = stretched[:, start_idx:start_idx + original_width]
        elif new_width < original_width:
            # Pad to original size
            pad_amount = original_width - new_width
            pad_left = pad_amount // 2
            pad_right = pad_amount - pad_left
            stretched = F.pad(stretched, (pad_left, pad_right))

        return stretched

    def add_noise(self, piano_roll, noise_level):
        """Add subtle noise"""
        noise = torch.randn_like(piano_roll) * noise_level
        return torch.clamp(piano_roll + noise, 0, 1)

    def velocity_variation(self, piano_roll, variation=0.1):
        """Vary note velocities"""
        # Only affect non-zero entries
        mask = piano_roll > 0
        variation_factor = 1 + (torch.randn_like(piano_roll) * variation)
        varied = piano_roll * variation_factor
        varied = torch.clamp(varied, 0, 1)
        # Keep zeros as zeros
        varied = varied * mask.float()
        return varied

    def random_mask(self, piano_roll, mask_prob=0.02):
        """Randomly mask some time steps (like SpecAugment)"""
        mask = torch.rand(piano_roll.size(-1)) > mask_prob
        masked = piano_roll.clone()
        masked[:, ~mask] = 0
        return masked

    def __call__(self, piano_roll, apply_prob=0.8):
        """Apply random augmentations"""
        piano_roll = torch.tensor(piano_roll, dtype=torch.float32)

        if torch.rand(1).item() > apply_prob:
            return piano_roll.numpy()

        # Random pitch shift
        if torch.rand(1).item() < 0.5:
            shift = torch.randint(-self.pitch_shift_range, self.pitch_shift_range + 1, (1,)).item()
            piano_roll = self.pitch_shift(piano_roll, shift)

        # Random time stretch
        if torch.rand(1).item() < 0.3:
            stretch = 1 + (torch.rand(1).item() - 0.5) * 2 * self.time_stretch_range
            piano_roll = self.time_stretch(piano_roll, stretch)

        # Random velocity variation
        if torch.rand(1).item() < 0.4:
            piano_roll = self.velocity_variation(piano_roll)

        # Random noise
        if torch.rand(1).item() < 0.3:
            piano_roll = self.add_noise(piano_roll, self.noise_level)

        # Random masking
        if torch.rand(1).item() < 0.2:
            piano_roll = self.random_mask(piano_roll)

        return piano_roll.numpy()

class AugmentedPianoRollDataset(Dataset):
    """Dataset with augmentation"""
    def __init__(self, data, labels, augment_train=True):
        self.data = data
        self.labels = labels
        self.augment_train = augment_train
        self.augmentor = MusicDataAugmentation() if augment_train else None

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        piano_roll = self.data[idx]
        label = self.labels[idx]

        # Apply augmentation during training
        if self.augmentor is not None:
            piano_roll = self.augmentor(piano_roll)

        # Convert to tensor and add channel dimension
        piano_roll = torch.tensor(piano_roll, dtype=torch.float32).unsqueeze(0)
        label = torch.tensor(label, dtype=torch.long)

        return piano_roll, label

# Create augmented datasets
print("🎵 Creating augmented datasets...")
train_dataset_aug = AugmentedPianoRollDataset(
    tracked_data[train_mask],
    tracked_labels[train_mask],
    augment_train=True
)

val_dataset_clean = AugmentedPianoRollDataset(
    tracked_data[val_mask],
    tracked_labels[val_mask],
    augment_train=False  # No augmentation for validation
)

# Create new data loaders
train_loader_aug = DataLoader(train_dataset_aug, batch_size=16, shuffle=True, num_workers=0)
val_loader_clean = DataLoader(val_dataset_clean, batch_size=16, shuffle=False, num_workers=0)

print("✅ Augmented datasets created!")
print(f"• Train samples: {len(train_dataset_aug)} (with augmentation)")
print(f"• Val samples: {len(val_dataset_clean)} (clean)")
print("• Augmentations: Pitch shift, time stretch, velocity variation, noise, masking")

🎵 Creating augmented datasets...
✅ Augmented datasets created!
• Train samples: 1986 (with augmentation)
• Val samples: 438 (clean)
• Augmentations: Pitch shift, time stretch, velocity variation, noise, masking


In [27]:
# =====================================================
# TRAIN SIMPLE MODEL WITH DATA AUGMENTATION
# =====================================================

import torch.optim as optim

def train_simple_model_with_augmentation():
    """Train the simple model with data augmentation for better generalization"""

    print("🎵 TRAINING SIMPLE MODEL WITH DATA AUGMENTATION...")
    print("Expected improvements:")
    print("• Better generalization from pitch shifting")
    print("• Robustness from time stretching")
    print("• Noise tolerance from velocity variation")
    print("• Overfitting reduction from random masking")

    # Use the proven simple model architecture
    aug_model = SimpleCNN_LSTM_LeakyReLU(num_classes=4, lstm_hidden=128).to(device)

    # Optimizer and scheduler
    optimizer = optim.AdamW(aug_model.parameters(), lr=3e-3, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss(weight=class_weights)

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.7, patience=4, verbose=True
    )

    # Training setup
    epochs = 30
    best_val_acc = 0.0
    patience_counter = 0
    patience = 10

    print(f"\n🎯 Augmentation Training Configuration:")
    print(f"• Model: Simple CNN-LSTM with LeakyReLU")
    print(f"• Data augmentation: Pitch shift, time stretch, velocity, noise, masking")
    print(f"• Learning rate: 3e-3 with ReduceLROnPlateau")
    print(f"• Class weights: {class_weights}")
    print(f"• Epochs: {epochs} (early stopping patience: {patience})")

    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': [], 'lr': []}

    for epoch in range(epochs):
        # ==========================================
        # TRAINING PHASE WITH AUGMENTATION
        # ==========================================
        aug_model.train()
        total_loss = 0
        correct = 0
        total = 0

        print(f"\nEpoch {epoch+1}/{epochs}")

        for batch_idx, (data, target) in enumerate(train_loader_aug):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            outputs = aug_model(data)
            loss = criterion(outputs, target)
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(aug_model.parameters(), max_norm=1.0)

            optimizer.step()

            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

            if batch_idx % 25 == 0:
                print(f"  Batch {batch_idx:3d}/{len(train_loader_aug)} | "
                      f"Loss: {loss.item():.4f} | "
                      f"Acc: {100.*correct/total:.2f}% | "
                      f"LR: {optimizer.param_groups[0]['lr']:.1e}")

        # ==========================================
        # VALIDATION PHASE (CLEAN DATA)
        # ==========================================
        aug_model.eval()
        val_loss = 0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for data, target in val_loader_clean:
                data, target = data.to(device), target.to(device)
                outputs = aug_model(data)
                loss = criterion(outputs, target)

                val_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                val_total += target.size(0)
                val_correct += (predicted == target).sum().item()

        # Calculate metrics
        train_loss = total_loss / len(train_loader_aug)
        train_acc = 100. * correct / total
        val_loss = val_loss / len(val_loader_clean)
        val_acc = 100. * val_correct / val_total

        # Record history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['lr'].append(optimizer.param_groups[0]['lr'])

        print(f"📊 Epoch {epoch+1} Summary:")
        print(f"  Train: Loss={train_loss:.4f}, Acc={train_acc:.2f}%")
        print(f"  Val:   Loss={val_loss:.4f}, Acc={val_acc:.2f}%")

        # Update learning rate
        scheduler.step(val_acc)

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            torch.save({
                'model_state_dict': aug_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'epoch': epoch,
                'history': history
            }, 'saved_models/simple_augmented_model.pth')
            print(f"  💾 New best model saved! Val Acc: {val_acc:.2f}%")

            if val_acc > 75.0:  # Great performance threshold
                print(f"  🎉 Excellent performance reached: {val_acc:.2f}%!")
        else:
            patience_counter += 1
            print(f"  ⏳ Patience: {patience_counter}/{patience}")

        # Early stopping
        if patience_counter >= patience:
            print(f"\n🛑 Early stopping! Best Val Acc: {best_val_acc:.2f}%")
            break

    print(f"\n✅ Augmentation training completed!")
    print(f"Best validation accuracy: {best_val_acc:.2f}%")
    print(f"Improvement over baseline: {best_val_acc - 71.92:.2f} percentage points")

    return aug_model, history, best_val_acc

# Run augmented training
print("🚀 Starting training with data augmentation...")
aug_model, aug_history, aug_best_acc = train_simple_model_with_augmentation()

🚀 Starting training with data augmentation...
🎵 TRAINING SIMPLE MODEL WITH DATA AUGMENTATION...
Expected improvements:
• Better generalization from pitch shifting
• Robustness from time stretching
• Noise tolerance from velocity variation
• Overfitting reduction from random masking
🎯 Building SIMPLE CNN-LSTM with LeakyReLU...
✅ Simple LeakyReLU architecture built!

🎯 Augmentation Training Configuration:
• Model: Simple CNN-LSTM with LeakyReLU
• Data augmentation: Pitch shift, time stretch, velocity, noise, masking
• Learning rate: 3e-3 with ReduceLROnPlateau
• Class weights: tensor([0.7927, 0.7847, 1.2697, 1.4781], device='cuda:0')
• Epochs: 30 (early stopping patience: 10)

Epoch 1/30
  Batch   0/125 | Loss: 1.3886 | Acc: 12.50% | LR: 3.0e-03
  Batch  25/125 | Loss: 1.3647 | Acc: 22.36% | LR: 3.0e-03
  Batch  50/125 | Loss: 1.3882 | Acc: 23.41% | LR: 3.0e-03
  Batch  75/125 | Loss: 1.4404 | Acc: 22.94% | LR: 3.0e-03
  Batch 100/125 | Loss: 1.3819 | Acc: 23.21% | LR: 3.0e-03
📊 Epoch 1 

In [28]:
# =====================================================
# SIMPLE CNN-LSTM-TRANSFORMER HYBRID
# =====================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class PositionalEncoding(nn.Module):
    """Lightweight positional encoding"""
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)

        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:x.size(0), :]

class SimpleCNN_LSTM_Transformer(nn.Module):
    def __init__(self, num_classes=4, lstm_hidden=128, transformer_dim=256, num_heads=8, num_layers=2):
        super(SimpleCNN_LSTM_Transformer, self).__init__()

        print("🎯 Building SIMPLE CNN-LSTM-Transformer...")
        print(f"• CNN: 3 blocks (like working model)")
        print(f"• LSTM: {lstm_hidden} hidden units")
        print(f"• Transformer: {transformer_dim} dim, {num_heads} heads, {num_layers} layers")
        print(f"• Designed for stability and performance")

        # ==========================================
        # CNN BACKBONE (SAME AS WORKING MODEL)
        # ==========================================
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2),
            nn.Conv2d(32, 32, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2),
            nn.Dropout2d(0.1),
            nn.MaxPool2d(kernel_size=(2, 2))
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 64, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Dropout2d(0.15),
            nn.MaxPool2d(kernel_size=(2, 2))
        )

        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 128, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Dropout2d(0.2),
            nn.MaxPool2d(kernel_size=(2, 2))
        )

        # ==========================================
        # LSTM LAYER (SAME AS WORKING MODEL)
        # ==========================================
        self.feature_size = 128 * 16
        self.lstm_hidden = lstm_hidden

        self.lstm = nn.LSTM(
            input_size=self.feature_size,
            hidden_size=lstm_hidden,
            num_layers=2,
            batch_first=True,
            dropout=0.2,
            bidirectional=True
        )

        # ==========================================
        # LIGHTWEIGHT TRANSFORMER
        # ==========================================
        self.transformer_dim = transformer_dim

        # Project LSTM output to transformer dimension
        self.lstm_to_transformer = nn.Linear(lstm_hidden * 2, transformer_dim)

        # Positional encoding
        self.pos_encoding = PositionalEncoding(transformer_dim)

        # Lightweight transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=transformer_dim,
            nhead=num_heads,
            dim_feedforward=transformer_dim * 2,  # Smaller feedforward
            dropout=0.1,
            activation='relu',  # Stick with ReLU for stability
            batch_first=True,
            norm_first=False  # Post-norm for stability
        )

        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,
            norm=nn.LayerNorm(transformer_dim)
        )

        # ==========================================
        # SIMPLE ATTENTION & CLASSIFICATION
        # ==========================================

        # Simple global attention pooling
        self.attention_pooling = nn.MultiheadAttention(
            embed_dim=transformer_dim,
            num_heads=4,  # Fewer heads for simplicity
            dropout=0.1,
            batch_first=True
        )

        # Simple classification head
        self.classifier = nn.Sequential(
            nn.LayerNorm(transformer_dim),
            nn.Linear(transformer_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes)
        )

        print("✅ Simple CNN-LSTM-Transformer built!")

    def forward(self, x):
        batch_size = x.size(0)

        # ==========================================
        # CNN FEATURE EXTRACTION
        # ==========================================
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)

        # Reshape for LSTM
        x = x.permute(0, 3, 1, 2)
        x = x.contiguous().view(batch_size, x.size(1), -1)

        # ==========================================
        # LSTM PROCESSING
        # ==========================================
        lstm_out, _ = self.lstm(x)  # (batch, seq_len, lstm_hidden*2)

        # ==========================================
        # TRANSFORMER PROCESSING
        # ==========================================

        # Project to transformer dimension
        transformer_input = self.lstm_to_transformer(lstm_out)  # (batch, seq_len, transformer_dim)

        # Add positional encoding
        seq_len = transformer_input.size(1)
        transformer_input = transformer_input.transpose(0, 1)  # (seq_len, batch, transformer_dim)
        transformer_input = self.pos_encoding(transformer_input)
        transformer_input = transformer_input.transpose(0, 1)  # (batch, seq_len, transformer_dim)

        # Transformer encoding
        transformer_out = self.transformer_encoder(transformer_input)  # (batch, seq_len, transformer_dim)

        # ==========================================
        # ATTENTION POOLING & CLASSIFICATION
        # ==========================================

        # Global attention pooling
        attended, _ = self.attention_pooling(
            transformer_out, transformer_out, transformer_out
        )

        # Global average pooling
        pooled = torch.mean(attended, dim=1)  # (batch, transformer_dim)

        # Classification
        output = self.classifier(pooled)

        return output

# Create the simple transformer model
simple_transformer_model = SimpleCNN_LSTM_Transformer(
    num_classes=4,
    lstm_hidden=128,
    transformer_dim=256,  # Reasonable size
    num_heads=8,         # Moderate number of heads
    num_layers=2         # Just 2 layers for stability
).to(device)

# Count parameters
total_params = sum(p.numel() for p in simple_transformer_model.parameters())
print(f"📊 Simple Transformer Model: {total_params:,} parameters")

# Test forward pass
test_input = torch.randn(4, 1, 128, 4500).to(device)
with torch.no_grad():
    output = simple_transformer_model(test_input)
    print(f"Output shape: {output.shape}")
    print(f"✅ Simple transformer model forward pass successful!")

🎯 Building SIMPLE CNN-LSTM-Transformer...
• CNN: 3 blocks (like working model)
• LSTM: 128 hidden units
• Transformer: 256 dim, 8 heads, 2 layers
• Designed for stability and performance
✅ Simple CNN-LSTM-Transformer built!
📊 Simple Transformer Model: 4,396,260 parameters
Output shape: torch.Size([4, 4])
✅ Simple transformer model forward pass successful!


In [29]:
# =====================================================
# SIMPLE CNN-LSTM-TRANSFORMER HYBRID
# =====================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class PositionalEncoding(nn.Module):
    """Lightweight positional encoding"""
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)

        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:x.size(0), :]

class SimpleCNN_LSTM_Transformer(nn.Module):
    def __init__(self, num_classes=4, lstm_hidden=128, transformer_dim=256, num_heads=8, num_layers=2):
        super(SimpleCNN_LSTM_Transformer, self).__init__()

        print("🎯 Building SIMPLE CNN-LSTM-Transformer...")
        print(f"• CNN: 3 blocks (like working model)")
        print(f"• LSTM: {lstm_hidden} hidden units")
        print(f"• Transformer: {transformer_dim} dim, {num_heads} heads, {num_layers} layers")
        print(f"• Designed for stability and performance")

        # ==========================================
        # CNN BACKBONE (SAME AS WORKING MODEL)
        # ==========================================
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2),
            nn.Conv2d(32, 32, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2),
            nn.Dropout2d(0.1),
            nn.MaxPool2d(kernel_size=(2, 2))
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 64, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Dropout2d(0.15),
            nn.MaxPool2d(kernel_size=(2, 2))
        )

        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 128, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Dropout2d(0.2),
            nn.MaxPool2d(kernel_size=(2, 2))
        )

        # ==========================================
        # LSTM LAYER (SAME AS WORKING MODEL)
        # ==========================================
        self.feature_size = 128 * 16
        self.lstm_hidden = lstm_hidden

        self.lstm = nn.LSTM(
            input_size=self.feature_size,
            hidden_size=lstm_hidden,
            num_layers=2,
            batch_first=True,
            dropout=0.2,
            bidirectional=True
        )

        # ==========================================
        # LIGHTWEIGHT TRANSFORMER
        # ==========================================
        self.transformer_dim = transformer_dim

        # Project LSTM output to transformer dimension
        self.lstm_to_transformer = nn.Linear(lstm_hidden * 2, transformer_dim)

        # Positional encoding
        self.pos_encoding = PositionalEncoding(transformer_dim)

        # Lightweight transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=transformer_dim,
            nhead=num_heads,
            dim_feedforward=transformer_dim * 2,  # Smaller feedforward
            dropout=0.1,
            activation='relu',  # Stick with ReLU for stability
            batch_first=True,
            norm_first=False  # Post-norm for stability
        )

        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,
            norm=nn.LayerNorm(transformer_dim)
        )

        # ==========================================
        # SIMPLE ATTENTION & CLASSIFICATION
        # ==========================================

        # Simple global attention pooling
        self.attention_pooling = nn.MultiheadAttention(
            embed_dim=transformer_dim,
            num_heads=4,  # Fewer heads for simplicity
            dropout=0.1,
            batch_first=True
        )

        # Simple classification head
        self.classifier = nn.Sequential(
            nn.LayerNorm(transformer_dim),
            nn.Linear(transformer_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes)
        )

        print("✅ Simple CNN-LSTM-Transformer built!")

    def forward(self, x):
        batch_size = x.size(0)

        # ==========================================
        # CNN FEATURE EXTRACTION
        # ==========================================
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)

        # Reshape for LSTM
        x = x.permute(0, 3, 1, 2)
        x = x.contiguous().view(batch_size, x.size(1), -1)

        # ==========================================
        # LSTM PROCESSING
        # ==========================================
        lstm_out, _ = self.lstm(x)  # (batch, seq_len, lstm_hidden*2)

        # ==========================================
        # TRANSFORMER PROCESSING
        # ==========================================

        # Project to transformer dimension
        transformer_input = self.lstm_to_transformer(lstm_out)  # (batch, seq_len, transformer_dim)

        # Add positional encoding
        seq_len = transformer_input.size(1)
        transformer_input = transformer_input.transpose(0, 1)  # (seq_len, batch, transformer_dim)
        transformer_input = self.pos_encoding(transformer_input)
        transformer_input = transformer_input.transpose(0, 1)  # (batch, seq_len, transformer_dim)

        # Transformer encoding
        transformer_out = self.transformer_encoder(transformer_input)  # (batch, seq_len, transformer_dim)

        # ==========================================
        # ATTENTION POOLING & CLASSIFICATION
        # ==========================================

        # Global attention pooling
        attended, _ = self.attention_pooling(
            transformer_out, transformer_out, transformer_out
        )

        # Global average pooling
        pooled = torch.mean(attended, dim=1)  # (batch, transformer_dim)

        # Classification
        output = self.classifier(pooled)

        return output

# Create the simple transformer model
simple_transformer_model = SimpleCNN_LSTM_Transformer(
    num_classes=4,
    lstm_hidden=128,
    transformer_dim=256,  # Reasonable size
    num_heads=8,         # Moderate number of heads
    num_layers=2         # Just 2 layers for stability
).to(device)

# Count parameters
total_params = sum(p.numel() for p in simple_transformer_model.parameters())
print(f"📊 Simple Transformer Model: {total_params:,} parameters")

# Test forward pass
test_input = torch.randn(4, 1, 128, 4500).to(device)
with torch.no_grad():
    output = simple_transformer_model(test_input)
    print(f"Output shape: {output.shape}")
    print(f"✅ Simple transformer model forward pass successful!")

🎯 Building SIMPLE CNN-LSTM-Transformer...
• CNN: 3 blocks (like working model)
• LSTM: 128 hidden units
• Transformer: 256 dim, 8 heads, 2 layers
• Designed for stability and performance
✅ Simple CNN-LSTM-Transformer built!
📊 Simple Transformer Model: 4,396,260 parameters
Output shape: torch.Size([4, 4])
✅ Simple transformer model forward pass successful!


In [30]:
# =====================================================
# TRAIN SIMPLE CNN-LSTM-TRANSFORMER
# =====================================================

def train_simple_transformer_model():
    """Train the simple transformer model"""

    print("🚀 TRAINING SIMPLE CNN-LSTM-TRANSFORMER...")
    print("Expected benefits:")
    print("• Better long-range dependencies from transformer")
    print("• Improved musical pattern recognition")
    print("• Stable training with proven CNN-LSTM base")

    # Conservative training settings
    optimizer = optim.AdamW(
        simple_transformer_model.parameters(),
        lr=2e-3,  # Slightly lower LR for transformer stability
        weight_decay=1e-4
    )

    criterion = nn.CrossEntropyLoss(weight=class_weights)

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.7, patience=5, verbose=True
    )

    epochs = 30
    best_val_acc = 0.0
    patience_counter = 0
    patience = 12

    print(f"\n🎯 Simple Transformer Training Configuration:")
    print(f"• Learning rate: 2e-3 (conservative for transformer)")
    print(f"• Gradient clipping: max_norm=1.0")
    print(f"• Epochs: {epochs} (patience: {patience})")
    print(f"• Using augmented training data")

    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': [], 'lr': []}

    for epoch in range(epochs):
        # ==========================================
        # TRAINING PHASE
        # ==========================================
        simple_transformer_model.train()
        total_loss = 0
        correct = 0
        total = 0

        print(f"\nEpoch {epoch+1}/{epochs}")

        for batch_idx, (data, target) in enumerate(train_loader_aug):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            outputs = simple_transformer_model(data)
            loss = criterion(outputs, target)
            loss.backward()

            # Gradient clipping (important for transformers)
            torch.nn.utils.clip_grad_norm_(simple_transformer_model.parameters(), max_norm=1.0)

            optimizer.step()

            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

            if batch_idx % 25 == 0:
                print(f"  Batch {batch_idx:3d}/{len(train_loader_aug)} | "
                      f"Loss: {loss.item():.4f} | "
                      f"Acc: {100.*correct/total:.2f}% | "
                      f"LR: {optimizer.param_groups[0]['lr']:.1e}")

        # ==========================================
        # VALIDATION PHASE
        # ==========================================
        simple_transformer_model.eval()
        val_loss = 0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for data, target in val_loader_clean:
                data, target = data.to(device), target.to(device)
                outputs = simple_transformer_model(data)
                loss = criterion(outputs, target)

                val_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                val_total += target.size(0)
                val_correct += (predicted == target).sum().item()

        # Calculate metrics
        train_loss = total_loss / len(train_loader_aug)
        train_acc = 100. * correct / total
        val_loss = val_loss / len(val_loader_clean)
        val_acc = 100. * val_correct / val_total

        # Record history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['lr'].append(optimizer.param_groups[0]['lr'])

        print(f"📊 Epoch {epoch+1} Summary:")
        print(f"  Train: Loss={train_loss:.4f}, Acc={train_acc:.2f}%")
        print(f"  Val:   Loss={val_loss:.4f}, Acc={val_acc:.2f}%")

        # Update learning rate
        scheduler.step(val_acc)

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            torch.save({
                'model_state_dict': simple_transformer_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'epoch': epoch,
                'history': history
            }, 'saved_models/simple_transformer_model.pth')
            print(f"  💾 New best model saved! Val Acc: {val_acc:.2f}%")

            if val_acc > 75.0:
                print(f"  🎉 Excellent performance reached: {val_acc:.2f}%!")
        else:
            patience_counter += 1
            print(f"  ⏳ Patience: {patience_counter}/{patience}")

        # Early stopping
        if patience_counter >= patience:
            print(f"\n🛑 Early stopping! Best Val Acc: {best_val_acc:.2f}%")
            break

    print(f"\n✅ Simple transformer training completed!")
    print(f"Best validation accuracy: {best_val_acc:.2f}%")
    print(f"Improvement over baseline: {best_val_acc - 71.92:.2f} percentage points")

    return simple_transformer_model, history, best_val_acc

# Run transformer training
print("🚀 Starting training of Simple CNN-LSTM-Transformer...")
transformer_model, transformer_history, transformer_best_acc = train_simple_transformer_model()

🚀 Starting training of Simple CNN-LSTM-Transformer...
🚀 TRAINING SIMPLE CNN-LSTM-TRANSFORMER...
Expected benefits:
• Better long-range dependencies from transformer
• Improved musical pattern recognition
• Stable training with proven CNN-LSTM base

🎯 Simple Transformer Training Configuration:
• Learning rate: 2e-3 (conservative for transformer)
• Gradient clipping: max_norm=1.0
• Epochs: 30 (patience: 12)
• Using augmented training data

Epoch 1/30
  Batch   0/125 | Loss: 1.4834 | Acc: 12.50% | LR: 2.0e-03
  Batch  25/125 | Loss: 1.3548 | Acc: 22.60% | LR: 2.0e-03
  Batch  50/125 | Loss: 1.4184 | Acc: 22.92% | LR: 2.0e-03
  Batch  75/125 | Loss: 1.4112 | Acc: 22.37% | LR: 2.0e-03
  Batch 100/125 | Loss: 1.4195 | Acc: 24.44% | LR: 2.0e-03
📊 Epoch 1 Summary:
  Train: Loss=1.4184, Acc=23.92%
  Val:   Loss=1.4007, Acc=35.16%
  💾 New best model saved! Val Acc: 35.16%

Epoch 2/30
  Batch   0/125 | Loss: 1.4570 | Acc: 12.50% | LR: 2.0e-03
  Batch  25/125 | Loss: 1.3635 | Acc: 27.40% | LR: 2.0

KeyboardInterrupt: 