In [5]:
# Imports and Setup
from mido import MidiFile, MidiTrack, MetaMessage

import os
import unittest

In [6]:
def extract_drum_track(input_file, output_file):
    # Initialize MIDI files
    in_mid = MidiFile(input_file)
    out_mid = MidiFile(ticks_per_beat=in_mid.ticks_per_beat)
    
    # Create drum track
    drum_track = MidiTrack()
    out_mid.tracks.append(drum_track)
    
    # Collect messages with absolute times
    all_messages = []
    current_time = 0
    
    # Process each track
    for track in in_mid.tracks:
        current_time = 0
        for msg in track:
            current_time += msg.time
            if (isinstance(msg, MetaMessage) and msg.type in ('set_tempo', 'time_signature')) or \
               (msg.type in ('note_on', 'note_off') and msg.channel == 9) or \
               (msg.type not in ('note_on', 'note_off')):
                all_messages.append((current_time, msg))
    
    # Sort and convert to delta times
    all_messages.sort(key=lambda x: x[0])
    print([msg.channel for msg in all_messages if hasattr(msg, 'channel')])
    last_time = 0
    
    for abs_time, msg in all_messages:
        delta = abs_time - last_time
        new_msg = msg.copy(time=delta)
        drum_track.append(new_msg)
        last_time = abs_time
    
    # Save and return
    out_mid.save(output_file)
    return out_mid

In [7]:
# Define the input and output directories
input_file = '../songs/midi_songs/for those about to rock.mid'
output_file = '../songs/midi_songs/test-split.mid'

extract_drum_track(input_file, output_file)

[]


MidiFile(type=1, ticks_per_beat=120, tracks=[
  MidiTrack([
    MetaMessage('time_signature', numerator=4, denominator=4, clocks_per_click=24, notated_32nd_notes_per_beat=8, time=0),
    MetaMessage('key_signature', key='D', time=0),
    MetaMessage('text', text='* ENGL', time=0),
    Message('program_change', channel=11, program=127, time=0),
    MetaMessage('set_tempo', tempo=444444, time=0),
    MetaMessage('text', text='* For those about to rock', time=0),
    Message('program_change', channel=3, program=103, time=0),
    Message('program_change', channel=0, program=30, time=0),
    Message('program_change', channel=1, program=29, time=0),
    Message('program_change', channel=9, program=16, time=0),
    Message('program_change', channel=2, program=34, time=0),
    Message('program_change', channel=12, program=127, time=0),
    Message('program_change', channel=4, program=30, time=0),
    Message('control_change', channel=11, control=7, value=127, time=0),
    Message('program_chan

In [8]:
class TestMidiDrumExtractor(unittest.TestCase):
    def setUp(self):
        self.input_file = '../songs/midi_songs/for those about to rock.mid'
        self.output_file = '../songs/midi_songs/test-split.mid'
    
    def test_file_creation(self):
        """Test if output file is created"""
        result = extract_drum_track(self.input_file, self.output_file)
        self.assertTrue(os.path.exists(self.output_file))
    
    def test_output_structure(self):
        """Test if output MIDI has correct structure"""
        result = extract_drum_track(self.input_file, self.output_file)
        self.assertIsInstance(result, MidiFile)
        self.assertEqual(len(result.tracks), 1)
    
    def test_drum_channel(self):
        """Test if output contains only drum channel messages"""
        result = extract_drum_track(self.input_file, self.output_file)
        for track in result.tracks:
            for msg in track:
                if hasattr(msg, 'channel') and msg.type == 'note_on':
                    self.assertEqual(msg.channel, 9)
    
    def test_no_non_drum_messages(self):
        """Test that no non-drum messages exist in the output."""
        result = extract_drum_track(self.input_file, self.output_file)
        for track in result.tracks:
            for msg in track:
                if hasattr(msg, 'channel') and msg.type == 'note_on':
                    self.assertEqual(msg.channel, 9)
                else:
                    self.assertNotIn(msg.type, ['note_on', 'note_off'])

    def test_timing_accuracy(self):
        """Test that drum messages maintain accurate timing."""
        original_mid = MidiFile(self.input_file)
        result = extract_drum_track(self.input_file, self.output_file)
        
        original_times = []
        result_times = []
        
        for track in original_mid.tracks:
            current_time = 0
            for msg in track:
                current_time += msg.time
                if hasattr(msg, 'channel') and msg.channel == 9:
                    original_times.append(current_time)
        
        current_time = 0
        for track in result.tracks:
            for msg in track:
                current_time += msg.time
                if hasattr(msg, 'channel') and msg.channel == 9:
                    result_times.append(current_time)
        
        self.assertEqual(original_times, result_times)

if __name__ == '__main__':
    unittest.main(argv=['first-arg-is-ignored'], exit=False)

.

[]


.

[]
[]


.

[]


.

[]


.
----------------------------------------------------------------------
Ran 5 tests in 0.626s

OK
