In [1]:
from collections import defaultdict
import os
import mido

In [None]:
class DrumChartGenerator:
    """Convert the drum part of a MIDI file into a ``.chart`` text file.

    The usual entry point is :meth:`generate_chart`, which runs the whole
    pipeline: derive output paths, load the MIDI file, collect tempo and
    drum-note events, render them as chart text and write the file.

    Attributes:
        CHART_RESOLUTION: Chart ticks per quarter note; MIDI ticks are
            rescaled to this resolution.
        DRUM_MAPPING: MIDI note number -> ``(chart note number, flag)``.
            The flag letter is carried on the in-memory note strings and
            removed again when the chart text is rendered.
        CYMBAL_MARKERS: MIDI note number -> extra chart note number that is
            emitted alongside the mapped note to mark it as a cymbal.
    """

    def __init__(self):
        self.CHART_RESOLUTION = 192
        self.DRUM_MAPPING = {
            35: (0, 'K'),  # Acoustic Bass Drum
            36: (0, 'K'),  # Bass Drum (Kick)
            38: (1, 'R'),  # Acoustic Snare
            40: (1, 'R'),  # Electric Snare
            42: (2, 'Y'),  # Closed Hi-Hat
            44: (2, 'Y'),  # Pedal Hi-Hat
            46: (3, 'B'),  # Open Hi-Hat
            49: (4, 'G'),  # Crash Cymbal 1
            51: (3, 'B'),  # Ride Cymbal
            45: (4, 'G'),  # Low Tom
            47: (3, 'B'),  # Mid Tom
            48: (2, 'Y'),  # High Mid Tom
            50: (2, 'Y'),  # High Tom
            57: (4, 'G'),  # Crash Cymbal 2
        }
        self.CYMBAL_MARKERS = {
            42: 66, 44: 66,  # Yellow cymbal
            46: 67, 51: 67,  # Blue cymbal
            49: 68, 57: 68,  # Green cymbal
        }

    def setup_paths(self, midi_file_path):
        """Remember the MIDI path and derive the track name and chart path.

        Args:
            midi_file_path: Path to the source MIDI file.

        Returns:
            ``(track_name, chart_path)`` where ``track_name`` is the MIDI
            file name without directory or extension.
        """
        self.midi_file_path = midi_file_path
        self.track_name = os.path.splitext(os.path.basename(midi_file_path))[0]
        self.chart_path = f'../../songs/chart_files/{self.track_name}.chart'
        return self.track_name, self.chart_path

    def initialize_song_metadata(self):
        """Return the ``[Song]`` section values; requires ``self.track_name``.

        String values already carry their surrounding double quotes so they
        can be written to the chart verbatim.
        """
        return {
            "Name": f"\"{self.track_name}\"",
            "Artist": "\"Unknown\"",
            "Charter": "\"ACE\"",
            "Album": "\"Generated Charts\"",
            "Year": "\"2024\"",
            "Offset": 0,
            "Resolution": self.CHART_RESOLUTION,
            "Player2": "\"bass\"",
            "Difficulty": 0,
            "PreviewStart": 0,
            "PreviewEnd": 0,
            "Genre": "\"Rock\""
        }

    def initialize_chart_data(self):
        """Return an empty chart structure with one entry per chart section.

        ``SyncTrack`` and ``ExpertDrums`` map a chart tick to the list of
        event strings at that tick.
        """
        return {
            "Song": self.initialize_song_metadata(),
            "SyncTrack": defaultdict(list),
            "Events": {},
            "ExpertDrums": defaultdict(list)
        }

    def load_midi_file(self):
        """Open ``self.midi_file_path`` with mido.

        Raises:
            FileNotFoundError: If the path does not exist.
        """
        if os.path.exists(self.midi_file_path):
            return mido.MidiFile(self.midi_file_path)
        raise FileNotFoundError(f"File not found at {self.midi_file_path}")

    def process_midi_track(self, mid, chart_data):
        """Fill ``chart_data`` with tempo and drum events from ``mid``.

        All tracks are merged into one stream. A fixed ``TS 4`` entry is
        written at tick 0 (time-signature messages in the MIDI are not
        read), every ``set_tempo`` message becomes a ``B`` entry and every
        sounding ``note_on`` is passed to :meth:`process_drum_note`.

        Args:
            mid: Object exposing ``tracks`` and ``ticks_per_beat``.
            chart_data: Structure from :meth:`initialize_chart_data`;
                modified in place.

        Returns:
            The same ``chart_data`` object.
        """
        merged_track = mido.merge_tracks(mid.tracks)
        chart_data["SyncTrack"][0].append("TS 4")
        current_tick = 0

        for msg in merged_track:
            # msg.time is a delta; accumulate it to get the absolute tick.
            current_tick += msg.time
            chart_tick = int(current_tick * self.CHART_RESOLUTION / mid.ticks_per_beat)

            if msg.type == 'set_tempo':
                # The value written is BPM * 1000. Scale before rounding so
                # fractional tempos survive (120.5 BPM -> 120500, not 120000).
                ch_tempo = round(mido.tempo2bpm(msg.tempo) * 1000)
                chart_data["SyncTrack"][chart_tick].append(f"B {ch_tempo}")
            elif msg.type == 'note_on' and msg.velocity > 0:
                self.process_drum_note(msg, chart_tick, chart_data)

        return chart_data

    def process_drum_note(self, msg, chart_tick, chart_data):
        """Append the chart note(s) for one MIDI drum hit.

        Messages that are not on channel index 9 or whose note is not in
        ``DRUM_MAPPING`` are ignored. Notes listed in ``CYMBAL_MARKERS``
        additionally get their cymbal-marker note.

        Args:
            msg: Message exposing ``channel`` and ``note``.
            chart_tick: Tick (in chart resolution) of the hit.
            chart_data: Chart structure; ``ExpertDrums`` is modified in place.
        """
        if getattr(msg, 'channel', None) != 9 or msg.note not in self.DRUM_MAPPING:
            return

        note_num, flag = self.DRUM_MAPPING[msg.note]
        suffix = f" {flag}" if flag else ""
        entries = [f"N {note_num} 0{suffix}"]

        marker = self.CYMBAL_MARKERS.get(msg.note)
        if marker is not None:
            entries.append(f"N {marker} 0{suffix}")

        events = chart_data["ExpertDrums"][chart_tick]
        for entry in entries:
            # Several MIDI notes share a lane (e.g. 35 and 36); when they
            # land on the same tick, emit the chart note only once.
            if entry not in events:
                events.append(entry)

    @staticmethod
    def _strip_flag(note):
        """Return ``note`` reduced to its first three fields (``N <num> <len>``).

        This removes the trailing flag letter when there is one and leaves
        flag-less note strings untouched.
        """
        return " ".join(note.split()[:3])

    def generate_chart_text(self, chart_data):
        """Render ``chart_data`` as the text of a ``.chart`` file.

        Sections are written in the order ``[Song]``, ``[SyncTrack]``,
        ``[Events]`` (always empty) and ``[ExpertDrums]``; tick-keyed
        sections are emitted in ascending tick order.

        Args:
            chart_data: Mapping with ``Song``, ``SyncTrack`` and
                ``ExpertDrums`` entries.

        Returns:
            The chart text, without a trailing newline.
        """
        lines = ["[Song]", "{"]
        lines.extend(f"  {key} = {value}" for key, value in chart_data["Song"].items())
        lines.append("}")

        sync_track = chart_data["SyncTrack"]
        lines += ["[SyncTrack]", "{"]
        for tick in sorted(sync_track):
            lines.extend(f"  {tick} = {event}" for event in sync_track[tick])
        lines.append("}")

        lines += ["[Events]", "{", "}"]

        drums = chart_data["ExpertDrums"]
        lines += ["[ExpertDrums]", "{"]
        for tick in sorted(drums):
            lines.extend(f"  {tick} = {self._strip_flag(note)}" for note in drums[tick])
        lines.append("}")

        return "\n".join(lines)

    def save_chart_file(self, chart_text):
        """Write ``chart_text`` to ``self.chart_path``.

        The parent directory is created when the path has one.

        Returns:
            ``True`` if the file exists and is non-empty after writing,
            otherwise ``False``.
        """
        directory = os.path.dirname(self.chart_path)
        if directory:
            # os.makedirs('') raises, so only create a directory when the
            # chart path actually names one.
            os.makedirs(directory, exist_ok=True)

        # Explicit encoding keeps the output independent of the platform default.
        with open(self.chart_path, 'w', encoding='utf-8') as f:
            f.write(chart_text)

        return os.path.exists(self.chart_path) and os.path.getsize(self.chart_path) > 0

    def generate_chart(self, midi_file_path):
        """Run the full MIDI -> chart pipeline for ``midi_file_path``.

        Returns:
            The result of :meth:`save_chart_file`.

        Raises:
            FileNotFoundError: If the MIDI file does not exist.
        """
        self.setup_paths(midi_file_path)
        chart_data = self.initialize_chart_data()
        mid = self.load_midi_file()
        chart_data = self.process_midi_track(mid, chart_data)
        chart_text = self.generate_chart_text(chart_data)
        return self.save_chart_file(chart_text)

In [None]:
import os
import tempfile
import unittest
from collections import defaultdict
from unittest.mock import patch

# Assume DrumChartGenerator is imported or defined in the same module.

class TestDrumChartGenerator(unittest.TestCase):
    """Unit tests for ``DrumChartGenerator``.

    File-system and mido access are patched or redirected to a temporary
    directory so the tests neither need a real MIDI file nor touch files in
    the working directory.
    """

    def setUp(self):
        # A fresh generator per test keeps attribute assignments isolated.
        self.generator = DrumChartGenerator()

    def test_setup_paths(self):
        track_name, chart_path = self.generator.setup_paths("dummy/path/test.mid")
        self.assertEqual(track_name, "test")
        self.assertEqual(chart_path, "../../songs/chart_files/test.chart")

    def test_initialize_song_metadata(self):
        self.generator.track_name = "dummy"
        metadata = self.generator.initialize_song_metadata()
        expected = {
            "Name": "\"dummy\"",
            "Artist": "\"Unknown\"",
            "Charter": "\"ACE\"",
            "Album": "\"Generated Charts\"",
            "Year": "\"2024\"",
            "Offset": 0,
            "Resolution": self.generator.CHART_RESOLUTION,
            "Player2": "\"bass\"",
            "Difficulty": 0,
            "PreviewStart": 0,
            "PreviewEnd": 0,
            "Genre": "\"Rock\""
        }
        self.assertDictEqual(metadata, expected)

    def test_initialize_chart_data(self):
        self.generator.track_name = "dummy"
        chart_data = self.generator.initialize_chart_data()
        for section in ("Song", "SyncTrack", "Events", "ExpertDrums"):
            self.assertIn(section, chart_data)
        self.assertEqual(chart_data["Song"], self.generator.initialize_song_metadata())
        self.assertIsInstance(chart_data["SyncTrack"], defaultdict)
        self.assertIsInstance(chart_data["ExpertDrums"], defaultdict)

    @patch('os.path.exists', return_value=True)
    @patch('mido.MidiFile')
    def test_load_midi_file_valid(self, mock_midi_file, mock_exists):
        self.generator.midi_file_path = "dummy.mid"
        dummy_midi = type("DummyMidi", (), {"ticks_per_beat": 480, "tracks": []})()
        mock_midi_file.return_value = dummy_midi
        midi = self.generator.load_midi_file()
        self.assertEqual(midi, dummy_midi)
        mock_midi_file.assert_called_once_with("dummy.mid")

    @patch('os.path.exists', return_value=False)
    def test_load_midi_file_not_found(self, mock_exists):
        # Patch the existence check instead of deleting a file from the
        # working directory: a test must not remove files it did not create.
        self.generator.midi_file_path = "nonexistent.mid"
        with self.assertRaises(FileNotFoundError):
            self.generator.load_midi_file()

    def test_process_drum_note_closed_hihat(self):
        class DummyMsg:
            pass
        msg = DummyMsg()
        msg.channel = 9
        msg.note = 42  # Closed Hi-Hat maps to (2, 'Y')
        msg.velocity = 64
        chart_data = {"ExpertDrums": defaultdict(list)}
        chart_tick = 100
        self.generator.process_drum_note(msg, chart_tick, chart_data)
        self.assertIn("N 2 0 Y", chart_data["ExpertDrums"][chart_tick])
        self.assertIn("N 66 0 Y", chart_data["ExpertDrums"][chart_tick])

    def test_generate_chart_text(self):
        chart_data = {
            "Song": {"Name": "\"dummy\"", "Artist": "\"Unknown\""},
            "SyncTrack": defaultdict(list, {0: ["TS 4", "B 120000"]}),
            "Events": {},
            "ExpertDrums": defaultdict(list, {96: ["N 0 0 K", "N 66 0 Y"]})
        }
        chart_text = self.generator.generate_chart_text(chart_data)
        self.assertIn("[Song]", chart_text)
        self.assertIn("Name = \"dummy\"", chart_text)
        self.assertIn("[SyncTrack]", chart_text)
        self.assertIn("0 = TS 4", chart_text)
        self.assertIn("0 = B 120000", chart_text)
        self.assertIn("[Events]", chart_text)
        self.assertIn("[ExpertDrums]", chart_text)
        self.assertIn("96 = N 0 0", chart_text)
        # The cymbal marker line must be rendered as well.
        self.assertIn("96 = N 66 0", chart_text)

    def test_save_chart_file(self):
        with tempfile.TemporaryDirectory() as temp_dir:
            self.generator.chart_path = os.path.join(temp_dir, "test.chart")
            chart_text = "dummy chart content"
            result = self.generator.save_chart_file(chart_text)
            self.assertTrue(result)
            with open(self.generator.chart_path, 'r') as f:
                content = f.read()
            self.assertEqual(content, chart_text)

    def test_generate_chart(self):
        fake_midi_path = "dummy.mid"
        dummy_midi = type("DummyMidi", (), {"ticks_per_beat": 480, "tracks": []})()
        # patch.object restores the real methods afterwards and lets the
        # test verify that each stage receives the previous stage's output.
        with patch.object(self.generator, "load_midi_file",
                          return_value=dummy_midi) as load, \
                patch.object(self.generator, "process_midi_track",
                             side_effect=lambda mid, d: d) as process, \
                patch.object(self.generator, "generate_chart_text",
                             return_value="chart text") as render, \
                patch.object(self.generator, "save_chart_file",
                             return_value=True) as save:
            result = self.generator.generate_chart(fake_midi_path)

        self.assertTrue(result)
        self.assertEqual(self.generator.track_name, "dummy")
        load.assert_called_once_with()
        self.assertIs(process.call_args[0][0], dummy_midi)
        render.assert_called_once()
        save.assert_called_once_with("chart text")

if __name__ == "__main__":
    # Run the tests in-process: a placeholder argv stops unittest from
    # parsing the host's command line, and exit=False stops it from calling
    # sys.exit() once the run has finished.
    unittest.main(exit=False, argv=["first-arg-is-ignored"])

.........
----------------------------------------------------------------------
Ran 9 tests in 0.008s

OK
