In [47]:
# Imports and Setup
from collections import defaultdict
from typing import Dict, Tuple

import unittest
import os
import mido
import unittest

In [48]:
# Constants
# General MIDI percussion note number (drums are read from mido channel 9) ->
# (chart drum lane, single-letter tag). process_note_message appends the tag
# to every "N" line it emits.
DRUM_MAPPING = {
    # Lane 0: 35 Acoustic Bass Drum, 36 Bass Drum 1
    35: (0, 'K'), 36: (0, 'K'),
    # Lane 1: 38 Acoustic Snare, 40 Electric Snare
    38: (1, 'R'), 40: (1, 'R'),
    # Lane 2: 42 Closed Hi-Hat, 44 Pedal Hi-Hat, 45 Low Tom,
    #         47 Low-Mid Tom, 48 Hi-Mid Tom
    42: (2, 'Y'), 44: (2, 'Y'), 45: (2, 'Y'), 47: (2, 'Y'), 48: (2, 'Y'),
    # Lane 3: 46 Open Hi-Hat
    46: (3, 'B'),
    # Lane 4: 49 Crash Cymbal 1, 50 High Tom, 51 Ride Cymbal 1, 57 Crash Cymbal 2
    49: (4, 'G'), 50: (4, 'G'), 51: (4, 'G'), 57: (4, 'G'),
}

def get_file_paths(midi_file_path: str) -> Tuple[str, str, str]:
    """Derive the output chart path and track name for a MIDI file.

    The track name is the MIDI file's base name without its extension. The
    chart is placed in a ``chart_files`` directory that sits beside the MIDI
    file's parent directory, e.g. ``songs/midi_songs/x.mid`` maps to
    ``songs/chart_files/x.chart``.

    Returns:
        ``(midi_file_path, chart_path, track_name)``, where the first element
        is the input returned unchanged.
    """
    midi_dir, file_name = os.path.split(midi_file_path)
    stem = os.path.splitext(file_name)[0]
    songs_root = os.path.dirname(midi_dir)
    chart_file = os.path.join(songs_root, 'chart_files', stem + '.chart')
    return midi_file_path, chart_file, stem

def create_midi_file(midi_file_path: str) -> mido.MidiFile:
    """Load *midi_file_path* as a ``mido.MidiFile``.

    Raises:
        FileNotFoundError: if nothing exists at the path. The check is made
            up front so that the message names the missing path.
    """
    if os.path.exists(midi_file_path):
        return mido.MidiFile(midi_file_path)
    raise FileNotFoundError(f"File not found at {midi_file_path}")

def initialize_song_metadata(track_name: str) -> Dict:
    """Build the default ``[Song]`` header fields for a generated chart.

    Only ``Name`` depends on the input (*track_name*). Every other field is a
    fixed placeholder. Insertion order matters because the chart writer emits
    the fields in dict order.
    """
    metadata: Dict = dict(Name=track_name)
    metadata.update(
        Artist="Unknown",
        Charter="AI Generated",
        Album="Generated Charts",
        Year="2024",
        Offset=0,
        Resolution=192,
        Player2="bass",
        Difficulty=0,
        PreviewStart=0,
        PreviewEnd=0,
        Genre="rock",
    )
    return metadata

def initialize_chart_data(song_metadata: Dict) -> Dict:
    """Create the empty in-memory chart that the converter fills in.

    *song_metadata* is stored by reference (not copied), so later changes to
    it show up under ``"Song"``. ``SyncTrack`` and ``ExpertDrums`` map a chart
    tick to a list of event strings and auto-create empty lists on first
    access. ``Events`` starts as a plain empty dict.
    """
    chart: Dict = {"Song": song_metadata}
    chart["SyncTrack"] = defaultdict(list)
    chart["Events"] = {}
    chart["ExpertDrums"] = defaultdict(list)
    return chart

def find_initial_tempo(mid: mido.MidiFile, default_tempo: int = 500000) -> int:
    """Return the first tempo in *mid*, in MIDI microseconds per beat.

    Tracks are scanned in order, and the first ``set_tempo`` message found
    wins. *default_tempo* is returned when the file has no tempo message. It
    must use the same unit as ``msg.tempo`` (µs per beat), because the result
    also feeds ``calculate_song_length``. 500000 µs/beat is 120 BPM, the tempo
    MIDI assumes when none is set. The old default of 120000 meant 500 BPM.
    """
    for track in mid.tracks:
        for msg in track:
            if msg.type == 'set_tempo':
                return msg.tempo
    return default_tempo

def initialize_sync_track(chart_data: Dict, initial_tempo: int, time_sig: Tuple[int, int]) -> None:
    """Seed tick 0 of ``chart_data["SyncTrack"]`` with the opening tempo and time signature.

    Args:
        chart_data: chart dict whose ``"SyncTrack"`` maps tick -> list of
            event strings.
        initial_tempo: MIDI tempo in microseconds per beat (as returned by
            ``find_initial_tempo``).
        time_sig: ``(numerator, denominator)`` of the opening time signature.

    A ``.chart`` ``B`` event stores beats-per-minute × 1000, so the MIDI
    tempo is converted rather than written verbatim. A ``TS`` event stores the
    numerator, followed by the denominator's base-2 exponent only when the
    denominator is not 4.
    """
    # 60,000,000 µs per minute, times 1000 for the chart's milli-BPM unit.
    milli_bpm = round(60_000_000_000 / initial_tempo)

    numerator, denominator = time_sig
    ts_event = f"TS {numerator}"
    if denominator != 4:
        # Denominators are powers of two; bit_length() - 1 gives log2.
        ts_event += f" {denominator.bit_length() - 1}"

    events = chart_data["SyncTrack"].setdefault(0, [])
    events.append(f"B {milli_bpm}")
    events.append(ts_event)

def calculate_total_ticks(mid: mido.MidiFile) -> int:
    """Return the length of the longest track in *mid*, in MIDI ticks.

    Each track's length is the sum of its messages' delta times. The result
    is never below 0 and is 0 when the file has no tracks.
    """
    track_lengths = [sum(message.time for message in track) for track in mid.tracks]
    return max([0] + track_lengths)

def calculate_song_length(total_ticks: int, tempo: int, ticks_per_beat: int) -> int:
    """Convert a MIDI tick count to a song length in whole milliseconds.

    Assumes one constant *tempo* (microseconds per beat) for the whole span;
    later tempo changes are not taken into account. The result is truncated,
    not rounded.
    """
    tick_microseconds = total_ticks * tempo
    ticks_per_second = ticks_per_beat * 1000000
    seconds = tick_microseconds / ticks_per_second
    return int(seconds * 1000)

def calculate_ticks_multiplier(ticks_per_beat: int, initial_tempo: int, resolution: int = 192) -> float:
    """Return the factor used to map MIDI ticks onto chart ticks.

    ``process_midi_messages`` computes
    ``chart_tick = midi_tick * multiplier / ticks_per_beat``. MIDI ticks and
    chart ticks both measure fractions of a beat, so the conversion is just a
    rescale from ``ticks_per_beat`` to the chart's ``Resolution`` (192 in
    ``initialize_song_metadata``). Tempo plays no part here, because tempo
    changes are carried separately by ``B`` events in the SyncTrack.

    The previous formula divided by the beat length in minutes, which
    stretched every note position by a tempo-dependent factor.

    Args:
        ticks_per_beat: unused; kept for call compatibility.
        initial_tempo: unused; kept for call compatibility.
        resolution: chart ticks per beat; should match the ``[Song]``
            ``Resolution`` field.
    """
    return float(resolution)

def process_note_message(msg: mido.Message, chart_tick: int, chart_data: Dict) -> None:
    """Append an ExpertDrums note for *msg* at *chart_tick* if it is a mapped drum hit.

    Only messages on mido channel 9 whose note number appears in
    ``DRUM_MAPPING`` produce output. All other messages are ignored. The
    emitted line has the form ``"N <lane> 0 <tag>"``.
    NOTE(review): the trailing tag is not part of the plain ``N fret length``
    note syntax; confirm that the target chart reader tolerates it.
    """
    if not hasattr(msg, 'channel') or msg.channel != 9:
        return
    mapping = DRUM_MAPPING.get(msg.note)
    if mapping is None:
        return
    lane, tag = mapping
    suffix = f" {tag}" if tag else ""
    chart_data["ExpertDrums"][chart_tick].append(f"N {lane} 0{suffix}")

def _set_sync_event(sync_track: Dict, tick: int, event: str) -> None:
    """Store SyncTrack *event* at *tick*, replacing any event of the same type there.

    The type is the event's first token (``B`` or ``TS``). Replacing instead
    of appending means a tempo or time-signature event read from the MIDI
    overrides the defaults seeded at tick 0 instead of duplicating or
    contradicting them. Among MIDI events at the same tick, the last one wins.
    """
    event_type = event.split(" ", 1)[0]
    kept = [e for e in sync_track.get(tick, []) if e.split(" ", 1)[0] != event_type]
    kept.append(event)
    sync_track[tick] = kept


def process_midi_messages(mid: mido.MidiFile, chart_data: Dict, ticks_multiplier: float) -> None:
    """Convert every message in *mid* into chart events stored in *chart_data*.

    MIDI delta times are summed into an absolute tick, which is mapped to a
    chart tick as ``int(tick * ticks_multiplier / mid.ticks_per_beat)``. Then:

    * ``note_on`` with velocity > 0 is passed to ``process_note_message``
      (ExpertDrums).
    * ``set_tempo`` becomes a ``B`` event holding BPM × 1000 (the MIDI value
      is microseconds per beat).
    * ``time_signature`` becomes ``TS <numerator>``, plus the denominator's
      base-2 exponent when the denominator is not 4.

    Tracks in type 0/1 files all start at tick 0, so the running tick is
    reset for every track. Type 2 files (independent patterns) keep the
    previous cumulative behaviour.
    """
    sync_track = chart_data["SyncTrack"]
    reset_per_track = getattr(mid, "type", 1) != 2
    current_tick = 0
    for track in mid.tracks:
        if reset_per_track:
            current_tick = 0
        for msg in track:
            current_tick += msg.time
            chart_tick = int(current_tick * ticks_multiplier / mid.ticks_per_beat)

            if msg.type == 'note_on' and msg.velocity > 0:
                process_note_message(msg, chart_tick, chart_data)
            elif msg.type == 'set_tempo':
                # 60,000,000 µs per minute, times 1000 for the chart's milli-BPM.
                milli_bpm = round(60_000_000_000 / msg.tempo)
                _set_sync_event(sync_track, chart_tick, f"B {milli_bpm}")
            elif msg.type == 'time_signature':
                ts_event = f"TS {msg.numerator}"
                if msg.denominator != 4:
                    # Denominators are powers of two; bit_length() - 1 is log2.
                    ts_event += f" {msg.denominator.bit_length() - 1}"
                _set_sync_event(sync_track, chart_tick, ts_event)

def _format_tick_section(name: str, events_by_tick: Dict) -> str:
    """Render a ``[name]`` section with one ``tick = event`` line per event, ticks ascending."""
    lines = [f"[{name}]", "{"]
    for tick in sorted(events_by_tick):
        lines.extend(f"  {tick} = {event}" for event in events_by_tick[tick])
    lines.append("}")
    return "\n".join(lines) + "\n"


def generate_chart_text(chart_data: Dict) -> str:
    """Serialize *chart_data* into ``.chart`` text.

    Emits ``[Song]``, ``[SyncTrack]``, ``[Events]`` and ``[ExpertDrums]``, in
    that order, separated by blank lines. ``Song`` values are written with
    ``str()`` in dict order. The three tick-keyed sections list their events
    sorted by tick, keeping insertion order within a tick. ``Events`` may be
    missing or empty, in which case an empty section is written.
    NOTE(review): Song values are written unquoted; confirm that the target
    chart reader accepts unquoted string fields.
    """
    song_lines = ["[Song]", "{"]
    song_lines.extend(f"  {key} = {value}" for key, value in chart_data["Song"].items())
    song_lines.append("}")

    sections = [
        "\n".join(song_lines) + "\n",
        _format_tick_section("SyncTrack", chart_data["SyncTrack"]),
        _format_tick_section("Events", chart_data.get("Events") or {}),
        _format_tick_section("ExpertDrums", chart_data["ExpertDrums"]),
    ]
    return "\n".join(sections)

def write_chart_file(chart_path: str, chart_text: str) -> bool:
    """Write *chart_text* to *chart_path*, creating parent directories as needed.

    The file is written as UTF-8, so that non-ASCII song names (taken from the
    MIDI file name) are encoded the same way on every platform, independent
    of the locale's default encoding.

    Returns:
        True if the file exists and is non-empty afterwards. False if an OS or
        encoding error occurred (the error is printed) or the file is empty.
    """
    try:
        parent_dir = os.path.dirname(chart_path)
        # os.makedirs('') raises, so skip it for a bare file name in the CWD.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(chart_path, 'w', encoding='utf-8') as f:
            f.write(chart_text)
        return os.path.exists(chart_path) and os.path.getsize(chart_path) > 0
    except (OSError, ValueError) as e:
        print(f"Error writing file: {e}")
        return False

def convert_midi_to_chart(midi_file_path: str) -> bool:
    """Convert one MIDI file into a ``.chart`` file and report the outcome.

    The pipeline:
    1. Resolve the output path.
    2. Load the MIDI file.
    3. Build the metadata and chart skeleton.
    4. Seed the SyncTrack with the first tempo and a 4/4 time signature.
    5. Record the song length.
    6. Translate all MIDI messages.
    7. Render and write the chart text.

    Progress and failures are printed. Any exception raised along the way is
    caught at this top-level boundary and reported as a False return.

    Returns:
        True if the chart file was written successfully, otherwise False.
    """
    try:
        source_path, chart_path, track_name = get_file_paths(midi_file_path)
        print(f"Converting {source_path} to {chart_path}")

        mid = create_midi_file(source_path)
        song_metadata = initialize_song_metadata(track_name)
        chart_data = initialize_chart_data(song_metadata)

        initial_tempo = find_initial_tempo(mid)
        initialize_sync_track(chart_data, initial_tempo, (4, 4))

        # chart_data["Song"] is this same dict, so Length ends up in the chart.
        song_metadata["Length"] = calculate_song_length(
            calculate_total_ticks(mid), initial_tempo, mid.ticks_per_beat
        )

        process_midi_messages(
            mid, chart_data, calculate_ticks_multiplier(mid.ticks_per_beat, initial_tempo)
        )

        success = write_chart_file(chart_path, generate_chart_text(chart_data))
        print(
            f"Successfully created chart at {chart_path}"
            if success
            else "Failed to create chart file"
        )
        return success
    except Exception as e:
        print(f"Error during conversion: {e}")
        return False

# Entry point when run as a script or notebook cell: convert the sample MIDI
# file. The path is relative to the current working directory.
if __name__ == "__main__":
    midi_path = '../songs/midi_songs/test-split.mid'
    convert_midi_to_chart(midi_path)


Converting ../songs/midi_songs/test-split.mid to ../songs\chart_files\test-split.chart
Successfully created chart at ../songs\chart_files\test-split.chart


In [49]:
class TestChartProcessing(unittest.TestCase):
    """Unit tests for the MIDI-to-chart helper functions."""

    def setUp(self):
        # Sample input path; not read by the current tests.
        self.midi_file_path = '../songs/midi_songs/test-split.mid'

    def test_initialize_song_metadata(self):
        """Metadata carries the track name and the fixed default fields."""
        metadata = initialize_song_metadata('test-split')
        expected_fields = {
            'Name': 'test-split',
            'Resolution': 192,
            'Charter': 'AI Generated',
            'Year': '2024',
        }
        for field, expected in expected_fields.items():
            with self.subTest(field=field):
                self.assertEqual(metadata[field], expected)

    def test_calculate_song_length(self):
        """1000 ticks at 480 ticks/beat and 120 BPM is 1041 ms (truncated)."""
        # tempo 500000 µs per beat == 120 BPM
        self.assertEqual(calculate_song_length(1000, 500000, 480), 1041)

    def test_generate_chart_text(self):
        """Rendered text contains the section headers and tick lines."""
        chart_text = generate_chart_text({
            "Song": {"Name": "test", "Artist": "Unknown"},
            "SyncTrack": {0: ["B 120000", "TS 4"]},
            "Events": {},
            "ExpertDrums": {192: ["N 1 0 R"]},
        })
        for fragment in ("[Song]", "Name = test", "[SyncTrack]", "192 = N 1 0 R"):
            with self.subTest(fragment=fragment):
                self.assertIn(fragment, chart_text)

    def test_process_note_message(self):
        """A channel-9 note 35 (kick) becomes lane 0 with tag K."""
        tick = 192
        chart_data = {"ExpertDrums": defaultdict(list)}
        kick = mido.Message('note_on', note=35, velocity=64, channel=9)

        process_note_message(kick, tick, chart_data)

        self.assertEqual(chart_data["ExpertDrums"][tick][0], "N 0 0 K")

# Run the test suite in-process. argv is replaced so unittest does not parse
# the host process's own command-line arguments (argv[0] is ignored), and
# exit=False stops unittest.main from calling sys.exit when it finishes.
if __name__ == '__main__':
    unittest.main(argv=['first-arg-is-ignored'], exit=False)


....
----------------------------------------------------------------------
Ran 4 tests in 0.002s

OK
