In [None]:
import os
import kagglehub
import zipfile
import shutil
import numpy as np
import torch
import torch.nn as nn

In [43]:
# Choose the compute device in priority order: Apple MPS first, then CUDA,
# and fall back to the CPU when neither accelerator is available.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

In [2]:
# Download the classical-music MIDI dataset through kagglehub and keep the
# returned location in `path` for the cells below.
path = kagglehub.dataset_download(
    "blanderbuss/midi-classic-music"
)
print(f"Path to dataset files: {path}")

Path to dataset files: /Users/arr/.cache/kagglehub/datasets/blanderbuss/midi-classic-music/versions/1


In [3]:
# Show every top-level entry (files and folders alike) of the download location.
print("\n".join(["Files and directories in dataset path:", *os.listdir(path)]))

Files and directories in dataset path:
Tchaikovsky Lake Of The Swans Act 1 6mov.mid
Arndt
Rothchild Symphony Rmw12 2mov.mid
Tchaicovsky Waltz of the Flowers.MID
Tchaikovsky Lake Of The Swans Act 2 14mov.mid
Tchaikovsky Lake Of The Swans Act 1 4mov.mid
Albéniz
Tchaikovsky Lake Of The Swans Act 2 10mov.mid
Tchaikovsky Lake Of The Swans Act 1 2mov.mid
midiclassics
Tchaikovsky Lake Of The Swans Act 2 12mov.mid
Alkan
Rothchlid Symphony Rmw12 3mov.mid
Tchaikovsky Lake Of The Swans Act 1 7-8movs.mid
Sibelius Kuolema Vals op44.mid
Wagner Ride of the valkyries.mid
Tchaikovsky Lake Of The Swans Act 1 5mov.mid
Tchaikovsky Lake Of The Swans Act 1 9mov.mid
Tchaikovsky Lake Of The Swans Act 1 1mov.mid
Arensky
Tchaikovsky Lake Of The Swans Act 2 11mov.mid
Tchaikovsky Lake Of The Swans Act 2 13mov.mid
Tchaikovsky Lake Of The Swans Act 1 3mov.mid
Ambroise
midiclassics.zip


In [None]:
# List only the sub-directories of the kagglehub download location (`path`),
# leaving out the loose MIDI files and the zip archive that sit beside them.
# (The previous comment and label referred to an unrelated local folder; the
# code has always inspected `path`.)
directories = [d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))]
print(f"Directories in '{path}':")
for d in directories:
    print(d)

Directories in 'data/NN_midi_files_extended/dev':
mozart
chopin
handel
byrd
schumann
mendelssohn
hummel
bach
bartok


In [5]:
# Unpack the archive shipped with the dataset into a project-local folder.
# `extract_path` is reused by the later clean-up and loading cells.
extract_path = os.path.join('data', 'kaggle', 'midiclassics')
zip_path = os.path.join(path, 'midiclassics.zip')
with zipfile.ZipFile(zip_path) as archive:  # default mode is read-only ('r')
    archive.extractall(path=extract_path)
print(f"Extracted files to: {extract_path}")

Extracted files to: data/kaggle/midiclassics


In [6]:
# Show every top-level entry of the extracted archive.
print("\n".join(["Files and directories in extracted folder:", *os.listdir(extract_path)]))

Files and directories in extracted folder:
Griffes
Mozart
Durand, E
Satie
Rothchild Piano Sonata Rmw13 2mov.mid
Liszt Bach Prelude Transcription.mid
Diabelli Sonatina op151 n1 2mov.mid
Liszt Paganini Etude n5.mid
Tchaikovsky Lake Of The Swans Act 1 6mov.mid
Arndt
Rothchild Symphony Rmw12 2mov.mid
Skriabin
Ginastera Estancia.mid
Bizet Carmen Prelude.mid
Rothchild Horn Concerto Rmw16 3mov.mid
Jakobowski
Chopin
Kuhlau Sonatina op55 n3 1mov.mid
Stravinski
Taube
Komzak
Lange
Mendelsonn
Tchaicovsky Waltz of the Flowers.MID
Reinecke Piano Concerto n3 1mov.mid
Katzwarra
Diabelli Sonatina op151 n2 1mov.mid
Vaughan
Diabelli Sonatina op151 n3 1mov.mid
Pachelbel
Coleridge-Taylor
Rossini
Czerny
Ravel
Buxethude Buxwv138 Prelude.mid
Finck
Durand, MA
Handel
Hiller
Rothchild Horn Concerto Rmw16 1mov.mid
Liszt Paganini Etude n3.mid
.DS_Store
Copland
Burgmuller
Liszt Paganini Etude n2.mid
Debussy Suite Bergamasque 2mov.mid
Tchaikovsky Lake Of The Swans Act 2 14mov.mid
MacBeth
Dvorak Symphony op70 n7 2mov

In [7]:
# Composers kept for the classification task; used below to filter the
# extracted files by name.
TARGET_COMPOSERS = ['Bach', 'Beethoven', 'Chopin', 'Mozart']

In [8]:
# For each target composer, report every top-level entry of the extracted
# folder whose name contains the composer's name (case-insensitive substring).
for composer in TARGET_COMPOSERS:
    needle = composer.lower()
    composer_files = []
    for entry in os.listdir(extract_path):
        if needle in entry.lower():
            composer_files.append(entry)
    print(f"Files for {composer}: {composer_files}")

Files for Bach: ['Liszt Bach Prelude Transcription.mid', 'midi_bach_flat', 'Bach', 'C.P.E.Bach Solfeggieto.mid']
Files for Beethoven: ['Beethoven', 'midi_beethoven_flat']
Files for Chopin: ['Chopin', 'midi_chopin_flat']
Files for Mozart: ['Mozart', 'midi_mozart_flat']


In [9]:
# Prune the extracted folder: every top-level entry (file or directory) whose
# name does not contain any target composer's name is deleted from disk.
# The match is a case-insensitive substring test, so entries that merely
# mention a composer (not only that composer's own folder) are kept as well.
for item in os.listdir(extract_path):
    lowered = item.lower()
    if any(name.lower() in lowered for name in TARGET_COMPOSERS):
        continue  # mentions a target composer -> keep it
    item_path = os.path.join(extract_path, item)
    if os.path.isfile(item_path):
        os.remove(item_path)
        print(f"Deleted file: {item_path}")
    elif os.path.isdir(item_path):
        shutil.rmtree(item_path)
        print(f"Deleted directory: {item_path}")

Deleted directory: data/kaggle/midiclassics/Griffes
Deleted directory: data/kaggle/midiclassics/Durand, E
Deleted directory: data/kaggle/midiclassics/Satie
Deleted file: data/kaggle/midiclassics/Rothchild Piano Sonata Rmw13 2mov.mid
Deleted file: data/kaggle/midiclassics/Diabelli Sonatina op151 n1 2mov.mid
Deleted file: data/kaggle/midiclassics/Liszt Paganini Etude n5.mid
Deleted file: data/kaggle/midiclassics/Tchaikovsky Lake Of The Swans Act 1 6mov.mid
Deleted directory: data/kaggle/midiclassics/Arndt
Deleted file: data/kaggle/midiclassics/Rothchild Symphony Rmw12 2mov.mid
Deleted directory: data/kaggle/midiclassics/Skriabin
Deleted file: data/kaggle/midiclassics/Ginastera Estancia.mid
Deleted file: data/kaggle/midiclassics/Bizet Carmen Prelude.mid
Deleted file: data/kaggle/midiclassics/Rothchild Horn Concerto Rmw16 3mov.mid
Deleted directory: data/kaggle/midiclassics/Jakobowski
Deleted file: data/kaggle/midiclassics/Kuhlau Sonatina op55 n3 1mov.mid
Deleted directory: data/kaggle/mid

In [10]:
# The substring filter above also keeps "C.P.E.Bach" entries because they
# contain "Bach". C.P.E. Bach was J.S. Bach's son and is not one of the target
# composers, so remove those entries explicitly.
for item in os.listdir(extract_path):
    if 'C.P.E.Bach' not in item:
        continue
    item_path = os.path.join(extract_path, item)
    if os.path.isfile(item_path):
        os.remove(item_path)
        print(f"Deleted file: {item_path}")
    elif os.path.isdir(item_path):
        shutil.rmtree(item_path)
        print(f"Deleted directory: {item_path}")

Deleted file: data/kaggle/midiclassics/C.P.E.Bach Solfeggieto.mid


In [None]:
%pip install pretty_midi

In [None]:
import os
import numpy as np
import pretty_midi

In [None]:
import torch
from torch.utils.data import Dataset
import random

class PianoRollDataset(Dataset):
    """In-memory dataset of fixed-length piano rolls with optional augmentation.

    Each item is a ``(1, pitches, T)`` float32 tensor (a leading channel axis is
    added for the CNN) together with its integer class label.

    Args:
        data: Array-like of shape ``(N, pitches, T)`` holding the piano rolls.
        labels: Array-like of ``N`` integer class ids.
        augment: When True, ``__getitem__`` applies random augmentation.
        normalize: When True (default) and ``data`` contains values above 1,
            the rolls are divided by ``max_velocity`` and clipped to [0, 1].
            Input that already lies in [0, 1] is left untouched.
        max_velocity: Divisor used by ``normalize`` (127 is the largest MIDI
            velocity).

    Why normalisation lives here: the augmentation path always ends with a
    clamp to [0, 1] and adds noise with a 0.02 scale, so augmented samples are
    in [0, 1] by construction. Without normalisation, non-augmented samples
    (e.g. the evaluation set) keep their raw magnitudes, and the model is then
    trained and evaluated on inputs of different scales.
    """

    def __init__(self, data, labels, augment=False, normalize=True, max_velocity=127.0):
        self.data = torch.tensor(data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)
        self.augment = augment

        # torch.tensor() always copies, so scaling in place cannot touch the
        # caller's array.
        if normalize and self.data.numel() > 0 and self.data.max() > 1.0:
            self.data.div_(max_velocity).clamp_(0.0, 1.0)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Clone so augmentation never modifies the stored roll.
        piano_roll = self.data[idx].clone()
        label = self.labels[idx]

        if self.augment:
            piano_roll = self._augment_piano_roll(piano_roll)

        # Add channel dimension for the CNN: (1, pitches, T)
        return piano_roll.unsqueeze(0), label

    def _augment_piano_roll(self, piano_roll):
        """Return a randomly perturbed copy of a ``(pitches, T)`` piano roll.

        Each step fires independently with its own probability; the result
        always has the input's shape and values in [0, 1].
        """
        n_pitches, n_frames = piano_roll.size(0), piano_roll.size(1)

        # 1. Pitch shift (transpose) by up to 6 rows; vacated rows become zeros
        #    and rows pushed past the edge are discarded.
        if random.random() < 0.4:
            shift = random.randint(-6, 6)
            if shift > 0:
                piano_roll = torch.cat(
                    [piano_roll.new_zeros(shift, n_frames), piano_roll[:-shift]], dim=0
                )
            elif shift < 0:
                piano_roll = torch.cat(
                    [piano_roll[-shift:], piano_roll.new_zeros(-shift, n_frames)], dim=0
                )

        # 2. Time stretch by +/-10% using nearest-frame resampling, then pad
        #    with silence or trim so the length is unchanged.
        if random.random() < 0.3:
            stretch_factor = random.uniform(0.9, 1.1)
            new_length = int(n_frames * stretch_factor)

            if new_length != n_frames:
                indices = torch.linspace(0, n_frames - 1, new_length).long()
                piano_roll = piano_roll[:, indices]

                if new_length < n_frames:
                    # Pad with the roll's own pitch count instead of a fixed 128.
                    pad = piano_roll.new_zeros(n_pitches, n_frames - new_length)
                    piano_roll = torch.cat([piano_roll, pad], dim=1)
                else:
                    piano_roll = piano_roll[:, :n_frames]

        # 3. Velocity variation: scale all intensities by a common factor.
        if random.random() < 0.5:
            piano_roll = piano_roll * random.uniform(0.8, 1.2)

        # 4. Small additive Gaussian noise.
        if random.random() < 0.3:
            piano_roll = piano_roll + torch.randn_like(piano_roll) * 0.02

        # 5. Random dropout: zero roughly 5% of the cells (individual
        #    pitch/frame entries, not whole notes).
        if random.random() < 0.2:
            keep_mask = torch.rand_like(piano_roll) > 0.05
            piano_roll = piano_roll * keep_mask.float()

        # Keep values inside the normalised range.
        return torch.clamp(piano_roll, 0, 1)

In [None]:
def get_piano_roll(midi_path, fs=100, max_length=3000):
    """Load a MIDI file and return its piano roll with a fixed number of frames.

    Args:
        midi_path: Path of the MIDI file, handed to ``pretty_midi.PrettyMIDI``.
        fs: Sampling rate forwarded to ``get_piano_roll``.
        max_length: Exact size of the returned array along its second axis.

    Returns:
        The piano-roll array, cut off after ``max_length`` columns when it is
        longer, otherwise zero-padded on the right up to ``max_length``.
    """
    roll = pretty_midi.PrettyMIDI(midi_path).get_piano_roll(fs=fs)
    n_frames = roll.shape[1]

    if n_frames > max_length:
        return roll[:, :max_length]

    # Shorter (or equal) rolls are extended with zeros on the time axis only.
    missing = max_length - n_frames
    return np.pad(roll, ((0, 0), (0, missing)), mode='constant')

In [None]:
import os
import numpy as np

# Build the dataset: one fixed-length piano roll per MIDI file found directly
# inside each target composer's folder, labelled with the composer's index.
# NOTE(review): only files directly inside <base_dir>/<composer> are read;
# nested sub-folders are not traversed. Confirm that is intended.
extract_path = os.path.join('data', 'kaggle', 'midiclassics')
base_dir = extract_path
target_composers = ['Bach', 'Beethoven', 'Chopin', 'Mozart']
composer_to_idx = {c: i for i, c in enumerate(target_composers)}

# Per-composer arrays, concatenated once at the end
all_data = []
all_labels = []

print("Loading MIDI files one composer at a time...")

for composer in target_composers:
    print(f"\n--- Processing {composer} ---")
    composer_dir = os.path.join(base_dir, composer)

    if not os.path.isdir(composer_dir):
        print(f"Directory not found: {composer_dir}")
        continue

    composer_data = []
    composer_labels = []
    files_processed = 0

    # sorted(): os.listdir returns entries in arbitrary order, which would make
    # the sample order (and therefore the seeded train/test split) differ
    # between machines and runs.
    for file in sorted(os.listdir(composer_dir)):
        if file.lower().endswith(('.mid', '.midi')):
            midi_path = os.path.join(composer_dir, file)
            try:
                piano_roll = get_piano_roll(midi_path)
                composer_data.append(piano_roll)
                composer_labels.append(composer_to_idx[composer])
                files_processed += 1

                if files_processed % 20 == 0:  # Progress indicator
                    print(f"  Processed {files_processed} files...")

            except Exception as e:
                # Unreadable/corrupt MIDI files are reported and skipped so one
                # bad file does not abort the whole load.
                print(f"  Error processing {midi_path}: {e}")

    print(f"Loaded {files_processed} files for {composer}")

    if composer_data:
        # float32 halves the memory of the stacked rolls; the dataset class
        # converts to float32 anyway, so no precision used downstream is lost.
        composer_data = np.asarray(composer_data, dtype=np.float32)
        composer_labels = np.array(composer_labels)

        all_data.append(composer_data)
        all_labels.append(composer_labels)

        print(f"  {composer} data shape: {composer_data.shape}")

        # Drop the local names; the arrays stay alive inside all_data/all_labels
        del composer_data, composer_labels

# Fail with a clear message instead of the opaque error np.concatenate raises
# on an empty list.
if not all_data:
    raise RuntimeError(
        f"No MIDI files could be loaded from {base_dir!r}; "
        "check that the archive was extracted and the composer folders exist."
    )

print("\nCombining all data...")
data = np.concatenate(all_data, axis=0)
labels = np.concatenate(all_labels, axis=0)

print(f"Final dataset shape: {data.shape}")
print(f"Final labels shape: {labels.shape}")
print(f"Composer mapping: {composer_to_idx}")

# Release the per-composer pieces now that they are merged
del all_data, all_labels

In [None]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the samples for testing. The split is stratified on the
# labels and seeded so it can be repeated.
X_train, X_test, y_train, y_test = train_test_split(
    data,
    labels,
    test_size=0.2,
    random_state=42,
    stratify=labels,
)

print(f"Training set: {len(X_train)} samples")
print(f"Test set: {len(X_test)} samples")
# Per-class sample counts, indexed by label id
print(f"Training labels distribution: {np.bincount(y_train)}")
print(f"Test labels distribution: {np.bincount(y_test)}")

In [None]:
from torch.utils.data import DataLoader, WeightedRandomSampler

# Datasets: random augmentation for training only, untouched data for testing.
train_dataset = PianoRollDataset(X_train, y_train, augment=True)
test_dataset = PianoRollDataset(X_test, y_test, augment=False)

# Inverse-frequency weight per class, then one weight per training sample.
class_counts = np.bincount(y_train)
if (class_counts == 0).any():
    # A label id with no training samples would get an infinite weight below.
    raise ValueError(f"Some classes have no training samples: counts={class_counts}")
class_weights = 1.0 / class_counts
sample_weights = class_weights[y_train]

# Draw training samples with replacement so each class is picked about equally
# often.
# NOTE(review): the loss defined later is also class-weighted; combining both
# compensates for imbalance twice. Confirm that is intended.
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True
)

# Pinned host memory only helps host-to-CUDA copies. The notebook prefers MPS
# and falls back to CPU, where requesting it just triggers a warning.
use_pinned_memory = torch.cuda.is_available()

train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    sampler=sampler,  # a sampler replaces shuffle=True
    num_workers=2,
    pin_memory=use_pinned_memory
)

test_loader = DataLoader(
    test_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=2,
    pin_memory=use_pinned_memory
)

print(f"Train loader: {len(train_loader)} batches")
print(f"Test loader: {len(test_loader)} batches")
print(f"Class distribution: {dict(zip(['Bach', 'Beethoven', 'Chopin', 'Mozart'], class_counts))}")
print(f"Class weights: {dict(zip(['Bach', 'Beethoven', 'Chopin', 'Mozart'], class_weights))}")

In [None]:
class CNN_LSTM_Classifier(nn.Module):
    """CNN -> bidirectional LSTM -> self-attention -> MLP classifier.

    Expects input of shape ``(batch, 1, 128, T)``. The four convolution stages
    only pool along the last axis (each halves it), so the 128-row axis keeps
    its size and the LSTM receives ``128 channels * 128 rows`` features per
    remaining time step.

    Args:
        num_classes: Size of the output logit vector.
        lstm_hidden: Hidden size per LSTM direction; the attention and the
            classifier operate on ``2 * lstm_hidden`` features.
    """

    def __init__(self, num_classes=4, lstm_hidden=128):
        super().__init__()

        def conv_stage(in_channels, out_channels, drop_p):
            # Conv -> BatchNorm -> ReLU -> channel dropout -> halve the time axis
            return [
                nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                nn.Dropout2d(drop_p),
                nn.MaxPool2d(kernel_size=(1, 2)),
            ]

        # (in_channels, out_channels, dropout) for the four stages, in order.
        stage_specs = [(1, 16, 0.2), (16, 32, 0.2), (32, 64, 0.3), (64, 128, 0.3)]
        cnn_layers = []
        for spec in stage_specs:
            cnn_layers.extend(conv_stage(*spec))
        self.cnn = nn.Sequential(*cnn_layers)

        # Features per time step after the CNN: 128 channels x 128 rows.
        # With T=3000 the four poolings leave 3000 // 16 = 187 time steps.
        self.lstm_input_size = 128 * 128

        bidirectional_width = lstm_hidden * 2  # forward + backward states

        self.lstm = nn.LSTM(
            input_size=self.lstm_input_size,
            hidden_size=lstm_hidden,
            num_layers=2,
            batch_first=True,
            dropout=0.3,
            bidirectional=True,
        )

        # Self-attention over the LSTM output sequence
        self.attention = nn.MultiheadAttention(
            embed_dim=bidirectional_width,
            num_heads=8,
            dropout=0.3,
            batch_first=True,
        )

        self.dropout = nn.Dropout(0.5)

        # Two-layer classification head
        self.classifier = nn.Sequential(
            nn.Linear(bidirectional_width, 256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Map ``(batch, 1, 128, T)`` piano rolls to ``(batch, num_classes)`` logits."""
        n_samples = x.size(0)

        feature_maps = self.cnn(x)  # (batch, 128, 128, T')

        # Move time to the sequence axis and flatten channels x rows into the
        # per-step feature vector: (batch, T', 16384)
        time_major = feature_maps.permute(0, 3, 1, 2).contiguous()
        sequence = time_major.view(n_samples, time_major.size(1), -1)

        lstm_out, _ = self.lstm(sequence)  # (batch, T', 2 * lstm_hidden)

        # Self-attention: the sequence serves as query, key and value
        attended, _ = self.attention(lstm_out, lstm_out, lstm_out)

        # Average over time to get one vector per sample
        summary = attended.mean(dim=1)

        return self.classifier(self.dropout(summary))

In [None]:
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import torch.nn.functional as F

# Build the network on the selected device
model = CNN_LSTM_Classifier(num_classes=4, lstm_hidden=128).to(device)

# Parameter counts: all parameters, and those that receive gradients
total_params = sum(param.numel() for param in model.parameters())
trainable_params = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print("Enhanced model with deeper CNN, bidirectional LSTM, and attention")

# Inverse-frequency class weights for the loss (rebinds `class_weights` from a
# NumPy array to a tensor on `device`)
inverse_frequencies = [1.0 / count for count in class_counts]
class_weights = torch.tensor(inverse_frequencies, dtype=torch.float32).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights)

# AdamW with decoupled weight decay
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=0.001,
    betas=(0.9, 0.999),
    weight_decay=1e-3,
)

# Cosine schedule with warm restarts: the first cycle lasts 10 scheduler steps,
# each following cycle is twice as long, and the rate never drops below 1e-6.
scheduler = CosineAnnealingWarmRestarts(
    optimizer,
    T_0=10,
    T_mult=2,
    eta_min=1e-6,
)

print(f"Using weighted loss function with weights: {class_weights.cpu().numpy()}")
print("Using CosineAnnealingWarmRestarts scheduler")

In [None]:
def train_model_enhanced(model, train_loader, test_loader, criterion, optimizer, scheduler, device,
                         epochs=25, checkpoint_path='best_model.pth'):
    """Train ``model`` with per-epoch validation, checkpointing and early stopping.

    Args:
        model: Network to optimise; it is modified in place.
        train_loader: Iterable of ``(data, target)`` training batches; must
            support ``len()``.
        test_loader: Iterable of ``(data, target)`` batches used as the
            validation set after every epoch; must support ``len()``.
        criterion: Loss called as ``criterion(output, target)``.
        optimizer: Optimiser stepped once per training batch.
        scheduler: Learning-rate scheduler stepped once per epoch.
        device: Device each batch is moved to.
        epochs: Maximum number of epochs.
        checkpoint_path: File the best weights are written to.

    Returns:
        ``(train_losses, train_accuracies, val_accuracies)``: one entry per
        completed epoch, accuracies in percent.

    On return the model holds the weights of the epoch with the highest
    validation accuracy. Training stops early after 7 consecutive epochs
    without a new best.

    NOTE(review): the loader passed as ``test_loader`` drives both model
    selection and early stopping. Accuracy measured later on that same loader
    is therefore optimistic; a separate validation split would avoid this.
    """
    import copy  # local import: keeps this cell independent of the import cell

    train_losses = []
    train_accuracies = []
    val_accuracies = []
    # Start below any reachable accuracy so the first epoch always records a
    # checkpoint. Starting at 0.0 with a strict '>' left nothing to restore
    # when validation accuracy never exceeded 0%, so a missing or stale
    # checkpoint file was loaded instead.
    best_val_acc = float('-inf')
    best_state = None
    patience = 7
    patience_counter = 0

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        correct = 0
        total = 0

        # Training phase
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()

            output = model(data)
            loss = criterion(output, target)
            loss.backward()

            # Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            # Release cached CUDA memory periodically
            if batch_idx % 10 == 0:
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            epoch_loss += loss.item()
            predicted = output.detach().argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

            if batch_idx % 10 == 0:
                print(f'Epoch {epoch+1}/{epochs}, Batch {batch_idx+1}/{len(train_loader)}, Loss: {loss.item():.4f}')

        # Training metrics for this epoch
        avg_loss = epoch_loss / len(train_loader)
        train_accuracy = 100 * correct / total
        train_losses.append(avg_loss)
        train_accuracies.append(train_accuracy)

        # Validation phase
        model.eval()
        val_correct = 0
        val_total = 0
        val_loss = 0.0

        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                val_loss += criterion(output, target).item()
                predicted = output.argmax(dim=1)
                val_total += target.size(0)
                val_correct += (predicted == target).sum().item()

        val_accuracy = 100 * val_correct / val_total
        val_avg_loss = val_loss / len(test_loader)
        val_accuracies.append(val_accuracy)

        # Advance the schedule; the rate printed below is the one the next
        # epoch will use.
        scheduler.step()

        print(f'Epoch {epoch+1}/{epochs} Complete:')
        print(f'  Train - Loss: {avg_loss:.4f}, Accuracy: {train_accuracy:.2f}%')
        print(f'  Val   - Loss: {val_avg_loss:.4f}, Accuracy: {val_accuracy:.2f}%')
        print(f'  LR: {scheduler.get_last_lr()[0]:.6f}')
        print('-' * 50)

        # Checkpointing and early stopping
        if val_accuracy > best_val_acc:
            best_val_acc = val_accuracy
            patience_counter = 0
            # Keep an in-memory snapshot so the final restore does not depend
            # on the file, and still write the checkpoint to disk.
            best_state = copy.deepcopy(model.state_dict())
            torch.save(best_state, checkpoint_path)
            print(f'  *** New best validation accuracy: {best_val_acc:.2f}% - Model saved! ***')
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f'Early stopping triggered after {patience} epochs without improvement')
            break

        # Release cached CUDA memory after each epoch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Restore the best weights seen during this run
    if best_state is not None:
        model.load_state_dict(best_state)
        print(f"\nTraining completed! Best validation accuracy: {best_val_acc:.2f}%")
    else:
        print("\nNo epoch was run; model weights left unchanged.")

    return train_losses, train_accuracies, val_accuracies

In [None]:
# Announce what this run includes, then start training.
print("\n".join([
    "Starting enhanced training with improved architecture and data augmentation...",
    "Key improvements:",
    "- Deeper CNN with batch normalization",
    "- Bidirectional LSTM with attention mechanism",
    "- Data augmentation (pitch shift, tempo, velocity, noise)",
    "- Class-balanced sampling and weighted loss",
    "- Better optimizer (AdamW) and scheduler (CosineAnnealing)",
    "- Early stopping and gradient clipping",
    "-" * 60,
]))

# The test loader doubles as the validation set inside the training function.
train_losses, train_accuracies, val_accuracies = train_model_enhanced(
    model,
    train_loader,
    test_loader,
    criterion,
    optimizer,
    scheduler,
    device,
    epochs=25,
)

print("Enhanced training completed!")

In [None]:
import matplotlib.pyplot as plt

# Three side-by-side panels: accuracy, training loss, learning-rate schedule.
fig, (ax_acc, ax_loss, ax_lr) = plt.subplots(1, 3, figsize=(15, 5))
epochs_range = range(1, len(train_accuracies) + 1)

# Panel 1: training vs. validation accuracy
ax_acc.plot(epochs_range, train_accuracies, 'b-', label='Training Accuracy', linewidth=2)
ax_acc.plot(epochs_range, val_accuracies, 'r-', label='Validation Accuracy', linewidth=2)
ax_acc.set_title('Model Accuracy Over Time')
ax_acc.set_xlabel('Epochs')
ax_acc.set_ylabel('Accuracy (%)')
ax_acc.legend()
ax_acc.grid(True, alpha=0.3)

# Panel 2: training loss
ax_loss.plot(epochs_range, train_losses, 'g-', label='Training Loss', linewidth=2)
ax_loss.set_title('Training Loss Over Time')
ax_loss.set_xlabel('Epochs')
ax_loss.set_ylabel('Loss')
ax_loss.legend()
ax_loss.grid(True, alpha=0.3)

# Panel 3: learning rate. The values were not recorded during training, so
# replay an identically configured scheduler on a throwaway optimizer and
# read the rate before each step.
temp_scheduler = CosineAnnealingWarmRestarts(
    torch.optim.AdamW(model.parameters(), lr=0.001),
    T_0=10, T_mult=2, eta_min=1e-6
)
lrs = []
for _ in range(len(train_accuracies)):
    lrs.append(temp_scheduler.get_last_lr()[0])
    temp_scheduler.step()

ax_lr.plot(epochs_range, lrs, 'm-', label='Learning Rate', linewidth=2)
ax_lr.set_title('Learning Rate Schedule')
ax_lr.set_xlabel('Epochs')
ax_lr.set_ylabel('Learning Rate')
ax_lr.set_yscale('log')
ax_lr.legend()
ax_lr.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Text summary. The last line compares against a fixed 50% reference value.
print("\n=== TRAINING SUMMARY ===")
print(f"Final Training Accuracy: {train_accuracies[-1]:.2f}%")
print(f"Final Validation Accuracy: {val_accuracies[-1]:.2f}%")
print(f"Best Validation Accuracy: {max(val_accuracies):.2f}%")
print(f"Improvement from baseline (~50%): +{max(val_accuracies) - 50:.2f}%")

In [None]:
def evaluate_model(model, test_loader, criterion, device):
    """Score ``model`` on every batch of ``test_loader`` without tracking gradients.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        test_loader: Iterable of ``(data, target)`` batches supporting ``len()``.
        criterion: Loss called as ``criterion(output, target)``.
        device: Device each batch is moved to.

    Returns:
        ``(accuracy, avg_loss, all_predictions, all_targets)``: accuracy in
        percent, the mean of the per-batch losses, and flat lists holding the
        predicted and true class ids for each sample.
    """
    model.eval()

    running_loss = 0.0
    n_correct = 0
    n_seen = 0
    all_predictions, all_targets = [], []

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)

            running_loss += criterion(logits, labels).item()

            # The predicted class is the index of the largest logit
            predicted = torch.max(logits, dim=1).indices
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()

            # Collected per sample for the detailed reports
            all_predictions.extend(predicted.cpu().numpy())
            all_targets.extend(labels.cpu().numpy())

    accuracy = 100 * n_correct / n_seen
    avg_loss = running_loss / len(test_loader)

    print("Test Results:")
    print(f"Test Loss: {avg_loss:.4f}")
    print(f"Test Accuracy: {accuracy:.2f}%")

    return accuracy, avg_loss, all_predictions, all_targets

In [None]:
from sklearn.metrics import classification_report, confusion_matrix

# Score the trained model on the test loader
test_accuracy, test_loss, predictions, targets = evaluate_model(model, test_loader, criterion, device)

# Class names in label-id order
composer_names = ['Bach', 'Beethoven', 'Chopin', 'Mozart']

print("\nDetailed Classification Report:")
print(classification_report(targets, predictions, target_names=composer_names))

print("\nConfusion Matrix:")
cm = confusion_matrix(targets, predictions)
print(cm)

# Per-composer accuracy: of the samples whose true label is this composer,
# the share that was predicted correctly.
for i, composer in enumerate(composer_names):
    outcomes = [p == i for t, p in zip(targets, predictions) if t == i]
    composer_total = len(outcomes)
    composer_correct = sum(1 for hit in outcomes if hit)
    if composer_total > 0:
        composer_acc = 100 * composer_correct / composer_total
    else:
        composer_acc = 0
    print(f"{composer}: {composer_acc:.1f}% ({composer_correct}/{composer_total})")