# Getting Started with the SOUSA Dataset

This notebook demonstrates how to load, explore, and use the SOUSA (Synthetic Open Unified Snare Assessment) dataset for machine learning tasks.

**Contents:**
1. Loading the Dataset
2. Exploring the Data Structure
3. Playing Audio Samples
4. Visualizing Score Distributions
5. Filtering for ML Training
6. Basic Model Training Example

## 1. Loading the Dataset

SOUSA can be loaded from the Hugging Face Hub or from a local directory.

In [None]:
# Option A: Load from HuggingFace Hub
from datasets import load_dataset

# The download only happens the first time this is run
dataset = load_dataset("zkeown/sousa")

print(f"Dataset loaded with splits: {list(dataset.keys())}")

# Report the size of each expected split, in train/validation/test order
for split_label, split_key in (
    ("Train", "train"),
    ("Validation", "validation"),
    ("Test", "test"),
):
    print(f"{split_label} samples: {len(dataset[split_key]):,}")

In [None]:
# Option B: Load from local directory (if you generated the dataset yourself)
import pandas as pd
from pathlib import Path

# Adjust this path to your local dataset
LOCAL_DATASET_DIR = Path("../output/dataset")

if LOCAL_DATASET_DIR.exists():
    # Both label tables are read from the "labels" sub-directory.
    labels_dir = LOCAL_DATASET_DIR / "labels"
    samples_df = pd.read_parquet(labels_dir / "samples.parquet")
    exercises_df = pd.read_parquet(labels_dir / "exercises.parquet")
    print(f"Loaded {len(samples_df):,} samples from local directory")
else:
    # Say so explicitly: otherwise the cell finishes silently and it is unclear
    # whether samples_df / exercises_df were ever created.
    print(f"No local dataset found at {LOCAL_DATASET_DIR}; skipping local load.")

## 2. Exploring the Data Structure

Each sample contains:
- **Audio**: FLAC file of the performance
- **Metadata**: rudiment, skill tier, tempo, etc.
- **Scores**: Performance quality scores (0-100 scale)

In [None]:
# Inspect the first training sample, one field per line
sample = dataset['train'][0]

print("Sample fields:")
for field, content in sample.items():
    if key_is_plain := (field != 'audio'):
        print(f"  {field}: {content}")
        continue
    # The waveform itself is too long to print; show its shape and rate instead
    waveform = content['array']
    rate = content['sampling_rate']
    print(f"  {field}: array shape {waveform.shape}, sr={rate}")

In [None]:
# View skill tier distribution
import collections

train_tiers = dataset['train']['skill_tier']
tier_counts = collections.Counter(train_tiers)
n_train_samples = len(dataset['train'])

print("Skill tier distribution:")
for tier in ('beginner', 'intermediate', 'advanced', 'professional'):
    # Counter lookups yield 0 for tiers that never occur
    count = tier_counts[tier]
    pct = 100 * count / n_train_samples
    print(f"  {tier:15s}: {count:6,} ({pct:5.1f}%)")

In [None]:
# View rudiment distribution (top 10)
rudiment_counts = collections.Counter(dataset['train']['rudiment_slug'])
top_rudiments = rudiment_counts.most_common(10)

print("Top 10 rudiments:")
for slug, n_samples in top_rudiments:
    print(f"  {slug:30s}: {n_samples:5,}")

## 3. Playing Audio Samples

You can listen to audio samples directly in the notebook.

In [None]:
from IPython.display import Audio, display

def play_sample(sample, label=None):
    """Print a sample's key metadata and show an inline audio player for it.

    Args:
        sample: Dataset row. Reads ``sample['audio']`` (its ``'array'`` and
            ``'sampling_rate'`` entries) plus the ``'rudiment_slug'``,
            ``'skill_tier'``, ``'tempo_bpm'`` and ``'overall_score'`` fields.
        label: Optional heading printed before the metadata; skipped when falsy.
    """
    clip = sample['audio']
    waveform = clip['array']
    rate = clip['sampling_rate']

    if label:
        print(label)

    # Metadata summary, one field per line
    print(f"  Rudiment: {sample['rudiment_slug']}")
    print(f"  Skill Tier: {sample['skill_tier']}")
    print(f"  Tempo: {sample['tempo_bpm']} BPM")
    print(f"  Overall Score: {sample['overall_score']:.1f}")

    player = Audio(waveform, rate=rate)
    display(player)

In [None]:
# Listen to a beginner vs professional comparison
train_data = dataset['train']

# Find a paradiddle from each tier
rudiment = 'single_paradiddle'
wanted_tiers = ('beginner', 'professional')

# Scan only the two metadata columns. Iterating over whole rows would load
# (and decode) the audio of every row visited just to compare two strings,
# and would also overwrite the `sample` variable from the earlier cell.
first_row_by_tier = {}
for row_idx, (slug, tier) in enumerate(
    zip(train_data['rudiment_slug'], train_data['skill_tier'])
):
    if slug == rudiment and tier in wanted_tiers and tier not in first_row_by_tier:
        first_row_by_tier[tier] = row_idx
        if len(first_row_by_tier) == len(wanted_tiers):
            break  # found the first match for both tiers

# Fetch full rows (including audio) only for the rows that were selected.
beginner_sample = (
    train_data[first_row_by_tier['beginner']]
    if 'beginner' in first_row_by_tier else None
)
professional_sample = (
    train_data[first_row_by_tier['professional']]
    if 'professional' in first_row_by_tier else None
)

if beginner_sample is not None:
    play_sample(beginner_sample, "BEGINNER:")
    print()

if professional_sample is not None:
    play_sample(professional_sample, "PROFESSIONAL:")

## 4. Visualizing Score Distributions

Explore how scores differ across skill tiers.

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Convert to DataFrame for easier analysis.
# The audio column is dropped first: the analysis below only needs metadata and
# scores, and carrying the audio payload into pandas costs memory for nothing.
train_metadata = dataset['train']
if 'audio' in train_metadata.column_names:
    train_metadata = train_metadata.remove_columns(['audio'])
train_df = train_metadata.to_pandas()

# Score columns
score_cols = ['overall_score', 'timing_accuracy', 'timing_consistency',
              'velocity_control', 'hand_balance']

# Filter to score columns that exist
score_cols = [c for c in score_cols if c in train_df.columns]

print(f"Available score columns: {score_cols}")

In [None]:
# Plot score distributions by skill tier
tier_order = ['beginner', 'intermediate', 'advanced', 'professional']
colors = ['#e74c3c', '#f39c12', '#2ecc71', '#3498db']

# One panel per score column; squeeze=False keeps `axes` indexable even
# when there is only a single panel.
n_panels = len(score_cols)
fig, axes = plt.subplots(1, n_panels, figsize=(4*n_panels, 4), squeeze=False)
axes = axes[0]

for ax, col in zip(axes, score_cols):
    # Overlay one translucent histogram per tier on the same panel
    for tier, color in zip(tier_order, colors):
        in_tier = train_df['skill_tier'] == tier
        tier_data = train_df.loc[in_tier, col]
        ax.hist(tier_data, bins=30, alpha=0.6, label=tier.capitalize(), color=color)
    ax.set(xlabel=col.replace('_', ' ').title(), ylabel='Count')
    ax.legend(fontsize=8)

fig.tight_layout()
fig.suptitle('Score Distributions by Skill Tier', y=1.02)
plt.show()

In [None]:
# Summary statistics by tier
print("Mean Overall Score by Skill Tier:")
print("-" * 40)

for tier in tier_order:
    in_tier = train_df['skill_tier'] == tier
    tier_scores = train_df.loc[in_tier, 'overall_score']
    mean_score = tier_scores.mean()
    score_spread = tier_scores.std()
    print(f"{tier.capitalize():15s}: {mean_score:5.1f} +/- {score_spread:4.1f}")

In [None]:
# Correlation heatmap of scores
import seaborn as sns

score_df = train_df[score_cols]
corr_matrix = score_df.corr()

# Draw on an explicit Axes rather than the implicit current figure
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', center=0, fmt='.2f', ax=ax)
ax.set_title('Score Correlation Matrix')
fig.tight_layout()
plt.show()

## 5. Filtering for ML Training

Filter the dataset based on your specific needs.

In [None]:
# Filter by skill tier
# input_columns passes only the skill_tier value to the predicate, so the
# filter does not have to load and decode each row's audio to test it.
advanced_professional = dataset['train'].filter(
    lambda tier: tier in ['advanced', 'professional'],
    input_columns=['skill_tier'],
)
print(f"Advanced + Professional samples: {len(advanced_professional):,}")

In [None]:
# Filter by rudiment
# input_columns passes only the rudiment_slug value to the predicate, so the
# filter does not have to load and decode each row's audio to test it.
paradiddles = dataset['train'].filter(
    lambda slug: 'paradiddle' in slug,
    input_columns=['rudiment_slug'],
)
print(f"Paradiddle samples: {len(paradiddles):,}")

In [None]:
# Filter by score range (e.g., samples with score > 80)
# input_columns passes only the overall_score value to the predicate, so the
# filter does not have to load and decode each row's audio to test it.
high_quality = dataset['train'].filter(
    lambda score: score > 80,
    input_columns=['overall_score'],
)
print(f"High quality (score > 80) samples: {len(high_quality):,}")

In [None]:
# Create a binary classification dataset (beginner vs professional)
def add_binary_label(example):
    """Attach an ``is_professional`` flag to a dataset row.

    Args:
        example: Mapping with a ``'skill_tier'`` entry. It is modified in place.

    Returns:
        The same ``example`` object, with ``'is_professional'`` set to 1 when
        the tier is ``'professional'`` and 0 for any other tier.
    """
    example['is_professional'] = int(example['skill_tier'] == 'professional')
    return example

# Keep only the two extreme tiers, then attach the 0/1 label.
# input_columns passes only the skill_tier value to the predicate, so the
# filter does not have to load and decode each row's audio to test it.
binary_dataset = dataset['train'].filter(
    lambda tier: tier in ['beginner', 'professional'],
    input_columns=['skill_tier'],
).map(add_binary_label)

# Count classes from the label column alone. Iterating over full rows (twice)
# would fetch every row's audio just to read a single integer.
binary_labels = list(binary_dataset['is_professional'])
n_professional = sum(binary_labels)
n_beginner = len(binary_labels) - n_professional

print(f"Binary classification samples: {len(binary_dataset):,}")
print(f"  Beginner: {n_beginner:,}")
print(f"  Professional: {n_professional:,}")

## 6. Basic Model Training Example

A minimal example showing how to train a classifier on SOUSA.

**Note:** This is a simplified example for demonstration. For production use, consider:
- Using a proper audio feature extractor (e.g., mel spectrograms)
- Larger models (e.g., Audio Spectrogram Transformer)
- Proper hyperparameter tuning

In [None]:
# Optional dependency: the training example below only runs when torch is installed
try:
    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader
except ImportError:
    TORCH_AVAILABLE = False
    print("PyTorch not available. Skip the model training example.")
    print("Install with: pip install torch")
else:
    # All imports succeeded, so later cells may use torch
    TORCH_AVAILABLE = True

In [None]:
if TORCH_AVAILABLE:
    # Simple feature extraction: RMS energy in chunks
    def extract_features(audio_array, n_chunks=16):
        """Summarize a waveform as per-chunk RMS energies.

        The signal is cut into ``n_chunks`` consecutive chunks of equal length
        (trailing samples that do not fill a whole chunk are ignored) and the
        root-mean-square amplitude of each chunk is returned.

        Args:
            audio_array: Array-like of audio samples, chunked along its first axis.
            n_chunks: Number of chunks, and therefore features, to produce.

        Returns:
            ``np.float32`` array of shape ``(n_chunks,)``. All zeros when the
            signal has fewer than ``n_chunks`` samples.

        Raises:
            ValueError: If ``n_chunks`` is smaller than 1.
        """
        if n_chunks < 1:
            raise ValueError(f"n_chunks must be >= 1, got {n_chunks}")

        samples = np.asarray(audio_array)
        if not np.issubdtype(samples.dtype, np.floating):
            # Integer PCM overflows when squared in its own dtype; promote first.
            samples = samples.astype(np.float64)

        chunk_size = len(samples) // n_chunks
        if chunk_size == 0:
            # Too short to put even one sample in each chunk. Averaging empty
            # slices would produce NaN features, so report zero energy instead.
            return np.zeros(n_chunks, dtype=np.float32)

        # Drop the ragged tail, then lay the signal out as one row per chunk.
        chunks = samples[:chunk_size * n_chunks].reshape(n_chunks, -1)
        return np.sqrt(np.mean(chunks ** 2, axis=1)).astype(np.float32)
    
    # Prepare data
    def prepare_batch(examples):
        """Turn dataset rows into a ``(features, labels)`` tensor pair.

        Args:
            examples: Iterable of rows, each providing ``['audio']['array']``
                and a ``'skill_tier'`` that is one of the four known tiers.

        Returns:
            Tuple ``(features, labels)``: the stacked ``extract_features``
            outputs (one row per example) and the matching tier ids.
        """
        # Skill tier to numeric (0-3)
        tier_map = {'beginner': 0, 'intermediate': 1, 'advanced': 2, 'professional': 3}

        feature_rows, tier_ids = [], []
        for row in examples:
            feature_rows.append(extract_features(row['audio']['array']))
            tier_ids.append(tier_map[row['skill_tier']])

        feature_tensor = torch.tensor(np.stack(feature_rows))
        label_tensor = torch.tensor(tier_ids)
        return feature_tensor, label_tensor
    
    print("Feature extraction function defined.")

In [None]:
if TORCH_AVAILABLE:
    # Simple classifier
    class SimpleClassifier(nn.Module):
        """Small fully connected network mapping feature vectors to class logits.

        Three linear layers; each of the first two is followed by ReLU and
        dropout (p=0.2).

        Args:
            input_dim: Size of each input feature vector.
            hidden_dim: Width of both hidden layers.
            num_classes: Number of output logits.
        """

        def __init__(self, input_dim=16, hidden_dim=32, num_classes=4):
            super().__init__()
            layers = [
                # hidden block 1
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.2),
                # hidden block 2
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.2),
                # output layer (raw logits, no softmax)
                nn.Linear(hidden_dim, num_classes),
            ]
            self.net = nn.Sequential(*layers)

        def forward(self, x):
            """Return the logits produced by the layer stack for ``x``."""
            logits = self.net(x)
            return logits
    
    model = SimpleClassifier()
    # Total number of scalar parameters across all layers
    n_params = sum(param.numel() for param in model.parameters())
    print(f"Model parameters: {n_params:,}")

In [None]:
if TORCH_AVAILABLE:
    # Quick training loop on a small subset
    # In practice, use the full dataset with proper batching

    # Subset sizes are capped at what each split actually holds, because
    # select() fails on indices beyond the end of the split.
    n_train = min(500, len(dataset['train']))
    n_val = min(100, len(dataset['validation']))
    n_epochs = 50
    log_every = 10  # report validation accuracy every this many epochs

    # Get small training and validation sets
    train_subset = dataset['train'].shuffle(seed=42).select(range(n_train))
    val_subset = dataset['validation'].shuffle(seed=42).select(range(n_val))

    # Prepare data
    X_train, y_train = prepare_batch(train_subset)
    X_val, y_val = prepare_batch(val_subset)

    # Training
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    print(f"Training simple classifier ({n_train} samples, {n_epochs} epochs)...")
    for epoch in range(n_epochs):
        # One full-batch gradient step per epoch
        model.train()
        optimizer.zero_grad()
        outputs = model(X_train)
        loss = criterion(outputs, y_train)
        loss.backward()
        optimizer.step()

        if (epoch + 1) % log_every == 0:
            # Switch to eval mode so dropout is disabled while measuring accuracy
            model.eval()
            with torch.no_grad():
                val_outputs = model(X_val)
                val_preds = val_outputs.argmax(dim=1)
                val_acc = (val_preds == y_val).float().mean().item()
            print(f"Epoch {epoch+1:3d}: loss={loss.item():.4f}, val_acc={val_acc:.3f}")

In [None]:
if TORCH_AVAILABLE:
    # Evaluate on validation set
    model.eval()
    with torch.no_grad():
        val_outputs = model(X_val)
        val_preds = val_outputs.argmax(dim=1)

    from sklearn.metrics import confusion_matrix, classification_report

    tier_names = ['beginner', 'intermediate', 'advanced', 'professional']
    # Pin the label set to all four tier ids. Without `labels`, the report
    # infers classes from the data and raises if a tier is missing from both
    # the targets and the predictions (target_names would not line up).
    tier_ids = list(range(len(tier_names)))

    y_true = y_val.numpy()
    y_pred = val_preds.numpy()

    # Confusion matrix
    print("Confusion Matrix (rows = true tier, columns = predicted tier):")
    print(confusion_matrix(y_true, y_pred, labels=tier_ids))
    print()

    print("Classification Report:")
    # zero_division=0 reports 0.0 for tiers that were never predicted instead
    # of emitting a warning for each undefined metric.
    print(classification_report(
        y_true,
        y_pred,
        labels=tier_ids,
        target_names=tier_names,
        zero_division=0,
    ))

    print("\nNote: This is a simplified example with basic features.")
    print("For better results, use mel spectrograms or pre-trained audio models.")

## Next Steps

Now that you've explored the dataset, here are some ideas for what to do next:

1. **Score Regression**: Train a model to predict overall_score from audio
2. **Rudiment Classification**: Classify which of the 40 rudiments is being played
3. **Skill Assessment**: Build a system that provides feedback on drumming quality
4. **Transfer Learning**: Fine-tune a pre-trained audio model (e.g., Wav2Vec2, HuBERT)

For more information, see the [SOUSA documentation](https://github.com/zkeown/rudimentary).