# Drum MIDI Pattern Analysis

This notebook provides tools for analyzing drum patterns in MIDI files. It includes functions for:
- Loading and parsing MIDI files
- Extracting drum-specific information
- Analyzing timing, velocity, and pattern structure
- Visualizing drum patterns and statistics

## Setup

First, let's install the required packages if you don't have them already:

In [None]:
# Install required packages
!pip install mido pretty_midi matplotlib numpy pandas seaborn librosa

In [None]:
# Import required libraries
import os
import glob
import mido
import pretty_midi
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict, Counter
from IPython.display import display  # explicit import so display() also works outside Jupyter's builtins

# Set plot style
plt.style.use('ggplot')
sns.set_context("notebook")

## MIDI Drum Mapping

In the General MIDI standard, channel 10 is reserved for drum instruments. Each note number corresponds to a specific drum sound.
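Libraries count channels from zero, so "channel 10" is channel index 9 in mido and pretty_midi. The same applies to raw MIDI bytes: a Note On status byte has the form `0x9n`, where the low nibble `n` is the zero-based channel, so a drum Note On appears as `0x99`. The helper below is a hypothetical, stdlib-only sketch (not part of mido or pretty_midi) that makes this explicit.

```python
DRUM_CHANNEL_INDEX = 9  # GM "channel 10", counted from zero

def is_drum_note_on(status_byte: int) -> bool:
    """Return True if a raw MIDI status byte is a Note On on the GM drum channel."""
    message_type = status_byte & 0xF0  # high nibble: message type (0x90 = Note On)
    channel = status_byte & 0x0F       # low nibble: zero-based channel
    return message_type == 0x90 and channel == DRUM_CHANNEL_INDEX

print(is_drum_note_on(0x99))  # True: Note On, channel index 9
print(is_drum_note_on(0x90))  # False: Note On, channel index 0
print(is_drum_note_on(0x89))  # False: Note Off, channel index 9
```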

In [None]:
# Define General MIDI drum mapping
GM_DRUM_MAP = {
    35: 'Acoustic Bass Drum',
    36: 'Bass Drum 1',
    37: 'Side Stick',
    38: 'Acoustic Snare',
    39: 'Hand Clap',
    40: 'Electric Snare',
    41: 'Low Floor Tom',
    42: 'Closed Hi Hat',
    43: 'High Floor Tom',
    44: 'Pedal Hi-Hat',
    45: 'Low Tom',
    46: 'Open Hi-Hat',
    47: 'Low-Mid Tom',
    48: 'Hi-Mid Tom',
    49: 'Crash Cymbal 1',
    50: 'High Tom',
    51: 'Ride Cymbal 1',
    52: 'Chinese Cymbal',
    53: 'Ride Bell',
    54: 'Tambourine',
    55: 'Splash Cymbal',
    56: 'Cowbell',
    57: 'Crash Cymbal 2',
    58: 'Vibraslap',
    59: 'Ride Cymbal 2',
    60: 'Hi Bongo',
    61: 'Low Bongo',
    62: 'Mute Hi Conga',
    63: 'Open Hi Conga',
    64: 'Low Conga',
    65: 'High Timbale',
    66: 'Low Timbale',
    67: 'High Agogo',
    68: 'Low Agogo',
    69: 'Cabasa',
    70: 'Maracas',
    71: 'Short Whistle',
    72: 'Long Whistle',
    73: 'Short Guiro',
    74: 'Long Guiro',
    75: 'Claves',
    76: 'Hi Wood Block',
    77: 'Low Wood Block',
    78: 'Mute Cuica',
    79: 'Open Cuica',
    80: 'Mute Triangle',
    81: 'Open Triangle'
}

# Common drum kit elements for simplified analysis
SIMPLIFIED_DRUM_MAP = {
    35: 'Kick', 36: 'Kick',  # Bass Drums
    38: 'Snare', 40: 'Snare',  # Snares
    42: 'Closed Hi-Hat', 44: 'Pedal Hi-Hat', 46: 'Open Hi-Hat',  # Hi-Hats
    49: 'Crash', 57: 'Crash',  # Crash Cymbals
    51: 'Ride', 59: 'Ride', 53: 'Ride Bell',  # Ride Cymbals
    41: 'Tom', 43: 'Tom', 45: 'Tom', 47: 'Tom', 48: 'Tom', 50: 'Tom'  # Toms
}
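Several pitches collapse onto one simplified name, so the values of `SIMPLIFIED_DRUM_MAP` contain duplicates ('Kick' appears twice, 'Tom' six times). Any code that loops over the categories should deduplicate them first. A minimal sketch of counting hits per simplified category, using a small excerpt of the map and treating unmapped pitches as 'Other' the way `extract_drum_notes` does; the pitch list is made-up illustration data.

```python
from collections import Counter

# Excerpt of SIMPLIFIED_DRUM_MAP, so the example is self-contained
SIMPLIFIED_EXCERPT = {35: 'Kick', 36: 'Kick', 38: 'Snare', 40: 'Snare', 42: 'Closed Hi-Hat'}

def count_simplified(pitches, mapping):
    """Count hits per simplified drum name; unmapped pitches count as 'Other'."""
    return Counter(mapping.get(p, 'Other') for p in pitches)

# Unique category names in first-seen order (dict.fromkeys drops repeated values)
categories = list(dict.fromkeys(SIMPLIFIED_EXCERPT.values()))

hits = count_simplified([36, 42, 38, 42, 35, 42, 40, 54], SIMPLIFIED_EXCERPT)
print(categories)  # ['Kick', 'Snare', 'Closed Hi-Hat']
print(hits['Kick'], hits['Snare'], hits['Closed Hi-Hat'], hits['Other'])  # 2 2 3 1
```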

## Loading and Parsing MIDI Files

In [None]:
def load_midi_file(file_path):
    """Load a MIDI file using both mido and pretty_midi for different analysis tasks"""
    try:
        # Load with mido for lower-level access
        mido_midi = mido.MidiFile(file_path)
        
        # Load with pretty_midi for higher-level analysis
        pm_midi = pretty_midi.PrettyMIDI(file_path)
        
        return {'mido': mido_midi, 'pretty_midi': pm_midi, 'file_path': file_path}
    except Exception as e:
        print(f"Error loading {file_path}: {e}")
        return None

def extract_drum_notes(midi_data):
    """Extract drum notes from a MIDI file and return as a DataFrame"""
    pm_midi = midi_data['pretty_midi']
    
    # Find drum instruments (channel 9 in pretty_midi, which is 0-indexed)
    drum_instruments = [inst for inst in pm_midi.instruments if inst.is_drum]
    
    if not drum_instruments:
        print(f"No drum tracks found in {midi_data['file_path']}")
        return pd.DataFrame()
    
    notes_data = []
    for instrument in drum_instruments:
        for note in instrument.notes:
            note_name = GM_DRUM_MAP.get(note.pitch, f"Unknown ({note.pitch})")
            simplified_name = SIMPLIFIED_DRUM_MAP.get(note.pitch, "Other")
            
            notes_data.append({
                'start_time': note.start,
                'end_time': note.end,
                'duration': note.end - note.start,
                'pitch': note.pitch,
                'velocity': note.velocity,
                'instrument_name': instrument.name,
                'drum_name': note_name,
                'simplified_name': simplified_name
            })
    
    # Convert to DataFrame and sort by start time
    df = pd.DataFrame(notes_data)
    if not df.empty:
        df = df.sort_values(by='start_time')
    
    return df

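def get_tempo_changes(midi_data):
    """Extract tempo changes from a MIDI file.

    Returns a DataFrame with columns 'tick' (absolute position in MIDI ticks,
    not seconds) and 'tempo' (bpm). Falls back to the MIDI default of 120 bpm
    when the file contains no set_tempo messages.
    """
    mido_midi = midi_data['mido']
    tempo_changes = []
    
    for track in mido_midi.tracks:
        track_tick = 0
        for msg in track:
            track_tick += msg.time  # msg.time is a delta in ticks within the track
            if msg.type == 'set_tempo':
                tempo = mido.tempo2bpm(msg.tempo)
                tempo_changes.append({'tick': track_tick, 'tempo': tempo})
    
    if not tempo_changes:
        return pd.DataFrame({'tick': [0], 'tempo': [120.0]})
    return pd.DataFrame(tempo_changes).sort_values(by='tick').reset_index(drop=True)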
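The positions collected by `get_tempo_changes` are MIDI ticks, not seconds. A `set_tempo` message stores microseconds per quarter note, so one tick lasts `tempo / (ticks_per_beat * 1_000_000)` seconds, and the conversion has to be done piecewise when the tempo changes (mido's `mido.tick2second` covers the single-tempo case). A stdlib-only sketch; the tempo-map format `[(tick, microseconds_per_beat), ...]` and both helper names are assumptions for this example.

```python
def ticks_to_seconds(tick, tempo_map, ticks_per_beat):
    """Convert an absolute tick to seconds given a tempo map sorted by tick.

    tempo_map: list of (start_tick, microseconds_per_beat).
    Ticks before the first entry use the MIDI default of 500000 us/beat (120 bpm).
    """
    seconds = 0.0
    prev_tick, prev_tempo = 0, 500_000
    for change_tick, tempo in tempo_map:
        if change_tick >= tick:
            break
        # Time spent at the previous tempo, up to this tempo change
        seconds += (change_tick - prev_tick) * prev_tempo / (ticks_per_beat * 1_000_000)
        prev_tick, prev_tempo = change_tick, tempo
    # Remaining ticks at the tempo in effect at `tick`
    seconds += (tick - prev_tick) * prev_tempo / (ticks_per_beat * 1_000_000)
    return seconds

def tempo_to_bpm(microseconds_per_beat):
    return 60_000_000 / microseconds_per_beat

tempo_map = [(0, 500_000), (480, 250_000)]  # 120 bpm, then 240 bpm from tick 480
print(ticks_to_seconds(480, tempo_map, 480))  # 0.5
print(ticks_to_seconds(960, tempo_map, 480))  # 0.75
print(tempo_to_bpm(250_000))                  # 240.0
```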

## Analysis Functions

In [None]:
def analyze_drum_density(notes_df, bin_size=0.25):
    """Count drum hits per time bin (bin_size is in seconds)"""
    if notes_df.empty:
        return None
    
    max_time = notes_df['start_time'].max()
    bins = np.arange(0, max_time + bin_size, bin_size)
    
    # Count hits per bin
    hist, _ = np.histogram(notes_df['start_time'], bins=bins)
    
    return pd.DataFrame({
        'time_start': bins[:-1],
        'time_end': bins[1:],
        'hits': hist
    })

def identify_common_patterns(notes_df, pattern_length=4, grid_resolution=16, tempo=120):
    """Identify common drum patterns by quantizing onsets to a grid.

    Assumes 4/4 time and a constant tempo in bpm (pass the file's tempo, e.g. from
    get_tempo_changes). grid_resolution is the number of steps per bar (16 = 16th notes)
    and pattern_length is the window size in bars, so each pattern has
    grid_resolution * pattern_length steps.
    """
    if notes_df.empty:
        return None
    
    # Work on a copy so the caller's DataFrame is not modified
    notes_df = notes_df.copy()
    
    # Beat duration and grid step size
    beat_duration = 60 / tempo
    grid_size = beat_duration / (grid_resolution / 4)  # one 16th note if grid_resolution=16
    window_steps = grid_resolution * pattern_length
    
    # Quantize note onsets to the nearest grid step
    notes_df['quantized_position'] = (notes_df['start_time'] / grid_size).round().astype(int)
    
    # Split into pattern_length-bar windows and find the position within each window
    notes_df['measure'] = notes_df['quantized_position'] // window_steps
    notes_df['position_in_measure'] = notes_df['quantized_position'] % window_steps
    
    # Unique drum categories (the map's values repeat, e.g. 'Kick' for 35 and 36)
    drum_types = list(dict.fromkeys(SIMPLIFIED_DRUM_MAP.values()))
    
    # Build one binary pattern per drum type per window
    patterns = defaultdict(list)
    for measure in sorted(notes_df['measure'].unique()):
        measure_df = notes_df[notes_df['measure'] == measure]
        
        for drum_type in drum_types:
            drum_notes = measure_df[measure_df['simplified_name'] == drum_type]
            if drum_notes.empty:
                continue  # skip windows where this drum is silent
            pattern = np.zeros(window_steps, dtype=int)
            pattern[drum_notes['position_in_measure'].to_numpy()] = 1
            patterns[drum_type].append(pattern)
    
    # Count how often each exact pattern occurs
    pattern_counts = {}
    for drum_type, drum_patterns in patterns.items():
        pattern_counts[drum_type] = Counter(tuple(p) for p in drum_patterns)
    
    return pattern_counts

def analyze_velocity_dynamics(notes_df):
    """Analyze velocity patterns for dynamics"""
    if notes_df.empty:
        return None
    
    # Group by drum type
    velocity_stats = notes_df.groupby('simplified_name')['velocity'].agg(
        ['count', 'mean', 'std', 'min', 'max']
    ).reset_index()
    
    # Calculate velocity changes over time for each drum type
    velocity_trends = {}
    for drum_type in notes_df['simplified_name'].unique():
        drum_notes = notes_df[notes_df['simplified_name'] == drum_type].sort_values(by='start_time')
        if len(drum_notes) > 1:
            velocity_trends[drum_type] = drum_notes[['start_time', 'velocity']]
    
    return {
        'stats': velocity_stats,
        'trends': velocity_trends
    }

def analyze_timing_variations(notes_df, grid_resolution=16, tempo=120):
    """Analyze timing deviations from a quantized grid (human feel/groove).

    Assumes a constant tempo in bpm (default 120). A positive deviation means a
    hit lands after the nearest grid point; a negative one means it lands before it.
    """
    if notes_df.empty:
        return None
    
    # Work on a copy so the caller's DataFrame is not modified
    notes_df = notes_df.copy()
    
    beat_duration = 60 / tempo
    grid_size = beat_duration / (grid_resolution / 4)  # one 16th note if grid_resolution=16
    
    # Nearest grid position for each onset
    notes_df['grid_position'] = (notes_df['start_time'] / grid_size).round() * grid_size
    
    # Timing deviation from the grid, in seconds
    notes_df['timing_deviation'] = notes_df['start_time'] - notes_df['grid_position']
    
    # Deviation statistics by drum type
    timing_stats = notes_df.groupby('simplified_name')['timing_deviation'].agg(
        ['count', 'mean', 'std', 'min', 'max']
    ).reset_index()
    
    return {
        'deviations': notes_df[['start_time', 'simplified_name', 'timing_deviation']],
        'stats': timing_stats
    }
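`identify_common_patterns` and `analyze_timing_variations` share the same grid arithmetic. At tempo T a beat lasts 60/T seconds and a 16th-note step lasts a quarter of that (0.125 s at 120 bpm). An onset's step index is `round(t / grid_size)`, and its deviation is the onset minus that grid point. The stdlib sketch below works through this for single onsets under the same assumptions (4/4 time, constant tempo, 16 steps per bar, 4-bar windows); the helper name `locate_onset` is hypothetical. A positive deviation means the hit lands after the grid (laid back), a negative one before it (pushed).

```python
def locate_onset(t, tempo=120, grid_resolution=16, pattern_length=4):
    """Quantize one onset time (seconds) to a grid.

    Returns (window, step_in_window, bar, step_in_bar, deviation_ms), assuming
    4/4 time, a constant tempo, and grid_resolution steps per bar.
    """
    grid_size = (60 / tempo) / (grid_resolution / 4)  # 0.125 s for 16ths at 120 bpm
    step = round(t / grid_size)                       # nearest grid step
    window_steps = grid_resolution * pattern_length   # 64 steps = 4 bars of 16ths
    deviation_ms = (t - step * grid_size) * 1000      # positive = late, negative = early
    return (step // window_steps, step % window_steps,
            step // grid_resolution, step % grid_resolution,
            deviation_ms)

for onset in (0.26, 2.1, 8.0, 0.74):
    window, pos, bar, step_in_bar, dev = locate_onset(onset)
    print(f"{onset:5.2f} s -> window {window}, step {pos}, bar {bar + 1}, "
          f"16th {step_in_bar + 1}, {dev:+.1f} ms")
```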

## Visualization Functions

In [None]:
def plot_drum_pattern(notes_df, start_time=0, duration=4, drum_order=None, tempo=120):
    """Plot a piano-roll style visualization of drum patterns.

    Vertical grid lines fall on every beat at the given tempo (bpm); every fourth
    line is emphasized (a 4/4 bar line when start_time is on a downbeat).
    """
    if notes_df.empty:
        print("No notes to plot")
        return
    
    # Filter notes within the time range
    end_time = start_time + duration
    plot_notes = notes_df[
        (notes_df['start_time'] >= start_time) & 
        (notes_df['start_time'] < end_time)
    ]
    
    if plot_notes.empty:
        print(f"No notes found in time range {start_time} to {end_time}")
        return
        
    # Drum order for the y-axis: the first entry is drawn at the bottom
    if drum_order is None:
        drum_order = [
            'Kick', 'Snare', 'Tom', 'Pedal Hi-Hat', 'Closed Hi-Hat', 'Open Hi-Hat',
            'Ride', 'Ride Bell', 'Crash', 'Other'
        ]
    
    # Keep only drums listed in drum_order
    plot_notes = plot_notes[plot_notes['simplified_name'].isin(drum_order)]
    
    # Map drum types to y-axis positions
    drum_positions = {drum: i for i, drum in enumerate(drum_order)}
    
    plt.figure(figsize=(12, 6))
    
    # Plot each note as a short horizontal bar, darker for higher velocity
    for _, note in plot_notes.iterrows():
        y_pos = drum_positions[note['simplified_name']]
        color_intensity = note['velocity'] / 127
        plt.plot(
            [note['start_time'], note['start_time'] + max(0.03, note['duration'])],
            [y_pos, y_pos],
            linewidth=6,
            color=plt.cm.Blues(0.3 + 0.7 * color_intensity),
            solid_capstyle='butt'
        )
    
    # Vertical grid lines on each beat, emphasizing every fourth beat (bar lines in 4/4)
    beat_duration = 60 / tempo
    beat_positions = np.arange(start_time, end_time + beat_duration, beat_duration)
    for i, pos in enumerate(beat_positions):
        alpha = 0.8 if i % 4 == 0 else 0.2
        plt.axvline(x=pos, color='gray', linestyle='-', alpha=alpha, linewidth=0.5)
    
    # Axis limits and labels
    plt.yticks(range(len(drum_order)), drum_order)
    plt.xlim(start_time, end_time)
    plt.ylim(-0.5, len(drum_order) - 0.5)
    plt.xlabel('Time (seconds)')
    plt.title('Drum Pattern')
    plt.grid(axis='y', linestyle='--', alpha=0.3)
    plt.tight_layout()
    
    plt.show()

def plot_drum_heatmap(pattern_counts, drum_type):
    """Plot a heatmap of the most common patterns for a specific drum type"""
    if drum_type not in pattern_counts or not pattern_counts[drum_type]:
        print(f"No patterns found for {drum_type}")
        return
    
    # Get the top 5 most common patterns
    top_patterns = pattern_counts[drum_type].most_common(5)
    if not top_patterns:
        print(f"No patterns found for {drum_type}")
        return
    
    # Create a heatmap of patterns
    pattern_length = len(top_patterns[0][0])
    pattern_matrix = np.array([pattern for pattern, _ in top_patterns])
    
    plt.figure(figsize=(12, 4))
    ax = sns.heatmap(
        pattern_matrix,
        cmap='Blues',
        cbar=False,
        linewidths=0.5,
        linecolor='gray'
    )
    
    # Label x-ticks on each beat as bar.beat
    if pattern_length == 64:  # Assuming 4 bars * 16 sixteenth-note steps
        x_ticks = np.arange(0, pattern_length, 4)
        x_tick_labels = [f"{(i//16)+1}.{((i%16)//4)+1}" for i in x_ticks]
        plt.xticks(x_ticks + 0.5, x_tick_labels)
    
    # Label each row with its rank and occurrence count (in the tick labels,
    # so the counts do not overlap them)
    plt.yticks(
        np.arange(len(top_patterns)) + 0.5,
        [f"Pattern {i+1} (n={count})" for i, (_, count) in enumerate(top_patterns)],
        rotation=0
    )
    plt.title(f"Most Common {drum_type} Patterns")
    plt.tight_layout()
    plt.show()

def plot_velocity_distribution(velocity_data):
    """Plot velocity distribution by drum type"""
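    # std is NaN for drum types hit only once; treat it as 0 so error bars and labels render
    stats = velocity_data['stats'].fillna({'std': 0})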
    
    plt.figure(figsize=(10, 6))
    
    # Plot mean velocity with error bars showing standard deviation
    plt.errorbar(
        x=stats['simplified_name'],
        y=stats['mean'], 
        yerr=stats['std'],
        fmt='o',
        capsize=5,
        ecolor='gray',
        markersize=8
    )
    
    # Add count as text
    for i, row in stats.iterrows():
        plt.text(
            i, row['mean'] + row['std'] + 2, 
            f"n={row['count']}", 
            ha='center'
        )
    
    plt.grid(linestyle='--', alpha=0.7)
    plt.ylim(0, 130)
    plt.ylabel('Velocity')
    plt.title('Velocity Distribution by Drum Type')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()

def plot_timing_deviation(timing_data):
    """Plot timing deviations (groove/feel)"""
    deviations = timing_data['deviations']
    
    plt.figure(figsize=(12, 6))
    
    # Create violin plot for timing deviations
    sns.violinplot(
        x='simplified_name',
        y='timing_deviation',
        data=deviations,
        inner='quartile',
        cut=0
    )
    
    # Add horizontal line at zero (perfect timing)
    plt.axhline(y=0, color='red', linestyle='--', alpha=0.7, linewidth=1)
    
    plt.title('Timing Deviations by Drum Type')
    plt.ylabel('Timing Deviation (seconds)')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()
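The x-tick labels in `plot_drum_heatmap` read as `bar.beat`. For a 64-step window of 16th notes (4 bars of 4/4), step `i` falls in bar `i // 16 + 1` and beat `(i % 16) // 4 + 1`. The same formula as a standalone helper (the name `step_label` is hypothetical):

```python
def step_label(i, steps_per_bar=16, steps_per_beat=4):
    """Map a 0-based 16th-note step to a 1-based 'bar.beat' label."""
    return f"{i // steps_per_bar + 1}.{(i % steps_per_bar) // steps_per_beat + 1}"

print([step_label(i) for i in (0, 4, 16, 20, 60)])  # ['1.1', '1.2', '2.1', '2.2', '4.4']
```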

## Dataset Analysis

In [None]:
def analyze_midi_dataset(folder_path):
    """Analyze all MIDI files in a folder"""
    # Find all MIDI files
    midi_files = glob.glob(os.path.join(folder_path, "*.mid")) + glob.glob(os.path.join(folder_path, "*.midi"))
    
    if not midi_files:
        print(f"No MIDI files found in {folder_path}")
        return
    
    print(f"Found {len(midi_files)} MIDI files")
    
    # Analyze each file
    results = []
    for file_path in midi_files:
        print(f"Analyzing {os.path.basename(file_path)}...")
        
        # Load MIDI file
        midi_data = load_midi_file(file_path)
        if midi_data is None:
            continue
        
        # Extract drum notes
        notes_df = extract_drum_notes(midi_data)
        if notes_df.empty:
            continue
        
        # Basic file statistics
        file_stats = {
            'file_name': os.path.basename(file_path),
            'total_notes': len(notes_df),
            'duration': notes_df['end_time'].max(),
            'unique_drums': notes_df['simplified_name'].nunique(),
            'notes_df': notes_df
        }
        
        results.append(file_stats)
    
    # Create summary DataFrame
    summary_df = pd.DataFrame([{
        'file_name': r['file_name'],
        'total_notes': r['total_notes'],
        'duration': r['duration'],
        'unique_drums': r['unique_drums']
    } for r in results])
    
    return {
        'summary': summary_df,
        'detailed_results': results
    }
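Python's `glob` matching is case-sensitive on non-Windows platforms, so files named `*.MID` are skipped by `analyze_midi_dataset`. A stdlib sketch of a case-insensitive alternative using `pathlib`; the helper `find_midi_files` is hypothetical, and its result could stand in for the `midi_files` list.

```python
from pathlib import Path

def find_midi_files(folder_path):
    """Return sorted paths of .mid/.midi files in a folder, matching extensions case-insensitively."""
    folder = Path(folder_path)
    if not folder.is_dir():
        return []  # mirror glob's behavior of returning nothing for a missing folder
    return sorted(
        str(p) for p in folder.iterdir()
        if p.is_file() and p.suffix.lower() in {'.mid', '.midi'}
    )
```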

## Example Usage

In [None]:
# Example folder path (replace with your actual path)
dataset_path = './drum_midi_dataset'

# Analyze the dataset
analysis_results = analyze_midi_dataset(dataset_path)

# Display summary
if analysis_results:
    print("Dataset Summary:")
    display(analysis_results['summary'])

## Detailed Analysis of a Single MIDI File

Let's perform a detailed analysis of one file as an example:

In [None]:
def analyze_single_file(file_path):
    """Perform detailed analysis of a single MIDI file"""
    print(f"Analyzing {os.path.basename(file_path)}...")
    
    # Load MIDI file
    midi_data = load_midi_file(file_path)
    if midi_data is None:
        print("Failed to load file")
        return
    
    # Extract drum notes
    notes_df = extract_drum_notes(midi_data)
    if notes_df.empty:
        print("No drum notes found")
        return
    
    # Display basic statistics
    print(f"Total notes: {len(notes_df)}")
    print(f"Duration: {notes_df['end_time'].max():.2f} seconds")
    print(f"Unique drum types: {notes_df['simplified_name'].nunique()}")
    
    # Count by drum type
    drum_counts = notes_df['simplified_name'].value_counts()
    print("\nDrum counts:")
    display(drum_counts)
    
    # Plot drum pattern
    print("\nDrum pattern visualization (first 4 seconds):")
    plot_drum_pattern(notes_df, start_time=0, duration=4)
    
    # Analyze note density
    print("\nAnalyzing note density...")
    density = analyze_drum_density(notes_df)
    
    plt.figure(figsize=(12, 4))
    plt.plot(density['time_start'], density['hits'], '-o', alpha=0.7)
    plt.grid(linestyle='--', alpha=0.7)
    plt.title('Drum Hit Density Over Time')
    plt.xlabel('Time (seconds)')
    plt.ylabel('Hits per bin')
    plt.show()
    
    # Identify patterns
    print("\nIdentifying common patterns...")
    patterns = identify_common_patterns(notes_df)
    
    # Plot patterns for kick and snare
    if 'Kick' in patterns:
        print("\nKick drum patterns:")
        plot_drum_heatmap(patterns, 'Kick')
    
    if 'Snare' in patterns:
        print("\nSnare drum patterns:")
        plot_drum_heatmap(patterns, 'Snare')
    
    # Analyze velocity
    print("\nAnalyzing velocity dynamics...")
    velocity_data = analyze_velocity_dynamics(notes_df)
    plot_velocity_distribution(velocity_data)
    
    # Analyze timing
    print("\nAnalyzing timing variations (groove)...")
    timing_data = analyze_timing_variations(notes_df)
    plot_timing_deviation(timing_data)
    
    return {
        'notes_df': notes_df,
        'density': density,
        'patterns': patterns,
        'velocity': velocity_data,
        'timing': timing_data
    }

In [None]:
# Example file analysis (replace with an actual file path)
if analysis_results and len(analysis_results['detailed_results']) > 0:
    # Use the first file from our dataset
    example_file = analysis_results['detailed_results'][0]['file_name']
    file_path = os.path.join(dataset_path, example_file)
    
    # Analyze the file
    detailed_analysis = analyze_single_file(file_path)
else:
    print("No files available to analyze. Please provide a valid MIDI file path.")
    # If you have a specific file path, you can analyze it directly:
    # detailed_analysis = analyze_single_file('/path/to/your/drum_file.midi')

## Export Analysis Results

In [None]:
def export_results(analysis_results, output_folder='./analysis_results'):
    """Export analysis results to CSV files"""
    os.makedirs(output_folder, exist_ok=True)
    
    # Export summary
    summary_path = os.path.join(output_folder, 'summary.csv')
    analysis_results['summary'].to_csv(summary_path, index=False)
    print(f"Exported summary to {summary_path}")
    
    # Export the note list for each file
    for result in analysis_results['detailed_results']:
        # splitext strips only the final extension, so both .mid and .midi are handled
        file_name = os.path.splitext(result['file_name'])[0]
        
        notes_csv_path = os.path.join(output_folder, f"{file_name}_notes.csv")
        result['notes_df'].to_csv(notes_csv_path, index=False)
    
    print(f"Exported detailed results to {output_folder}")
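The file stem for exported CSVs should come from `os.path.splitext` rather than chained `str.replace` calls. Replacing `'.mid'` first turns `song.midi` into `songi`, and `replace` also removes the substring anywhere else in the name. A quick check:

```python
import os

name = "song.midi"
print(name.replace('.mid', '').replace('.midi', ''))  # songi (wrong)
print(os.path.splitext(name)[0])                      # song
print(os.path.splitext("my.mid.remix.mid")[0])        # my.mid.remix
```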

In [None]:
# Export results (uncomment to use)
# if analysis_results:
#     export_results(analysis_results)

## Conclusion

This notebook provides tools for loading drum MIDI files and analyzing their patterns, dynamics, and timing. You can use it to:

1. Extract and visualize drum patterns
2. Identify common rhythmic patterns
3. Analyze velocity dynamics and timing variations (groove)
4. Analyze entire datasets of drum MIDI files

To use this notebook with your own data:
1. Update the `dataset_path` variable to point to your folder of MIDI files
2. Run the analysis cells
3. Explore the visualizations and results
4. Optionally export the analysis for further processing

This analysis can be particularly useful for studying different drumming styles, creating drum pattern libraries, or developing AI models for drum pattern generation.