# dataset building
## imports

In [None]:
import os
from mido import MidiFile, MetaMessage, second2tick, tick2second, bpm2tempo
import pretty_midi
import numpy as np

from collections import Counter
import mido
import matplotlib.pyplot as plt

from rich import print
from rich.progress import track

from typing import List

## helper functions

In [None]:
def set_tempo(input_file_path, bpm) -> None:
    """Rewrite the MIDI file in place so that it carries a single tempo.

    Every existing ``set_tempo`` event is dropped and one new event for
    ``bpm`` is inserted at the very start of the first track. Dropping the
    old events matters: an old tempo event left at tick 0 would follow the
    inserted one and take effect instead of it, and repeated calls would keep
    stacking tempo events in the file.

    Args:
        input_file_path: path of the MIDI file to modify (overwritten).
        bpm: tempo in beats per minute.
    """
    mid = MidiFile(input_file_path)
    tempo_message = MetaMessage("set_tempo", tempo=bpm2tempo(bpm), time=0)

    for midi_track in mid.tracks:
        kept = []
        carried_ticks = 0
        for msg in midi_track:
            if msg.type == "set_tempo":
                # hand the removed event's delta to the next message so the
                # remaining events stay on the same tick positions
                carried_ticks += msg.time
                continue
            if carried_ticks:
                msg = msg.copy(time=msg.time + carried_ticks)
                carried_ticks = 0
            kept.append(msg)
        # slice assignment keeps the original MidiTrack object
        midi_track[:] = kept

    # a file without any track still needs somewhere to put the tempo event
    if not mid.tracks:
        mid.add_track()

    mid.tracks[0].insert(0, tempo_message)
    mid.save(input_file_path)

In [None]:
def modify_end_of_track(midi_file_path, new_end_time, tempo):
    """Move each track's ``end_of_track`` event to ``new_end_time`` seconds.

    For every track that has an ``end_of_track`` event, the event is removed
    and a new one is appended whose delta time places it at ``new_end_time``.
    If the track's remaining events already run past that point, the new
    event gets a delta of 0. Tracks without an ``end_of_track`` event are
    left untouched. The file is overwritten in place.

    Args:
        midi_file_path: path of the MIDI file to modify.
        new_end_time: desired end of the track, in seconds.
        tempo: tempo in beats per minute, used for second/tick conversion.
    """
    mid = MidiFile(midi_file_path)
    midi_tempo = bpm2tempo(tempo)
    ticks_per_beat = mid.ticks_per_beat
    # use the file's own resolution rather than assuming a fixed one;
    # int(round(...)) guards against a float result, which cannot be saved
    new_end_time_t = int(round(second2tick(new_end_time, ticks_per_beat, midi_tempo)))

    for i, midi_track in enumerate(mid.tracks):
        if not any(msg.type == "end_of_track" for msg in midi_track):
            continue

        # build the new contents separately instead of removing/appending
        # while iterating over the same list
        remaining = [msg for msg in midi_track if msg.type != "end_of_track"]

        # the new event's delta is relative to the last remaining event, so
        # every delta counts towards the elapsed time, not just note events
        total_time_t = sum(msg.time for msg in remaining)
        offset = max(new_end_time_t - total_time_t, 0)

        midi_track[:] = remaining
        midi_track.append(MetaMessage("end_of_track", time=offset))

        net = tick2second(new_end_time_t, ticks_per_beat, midi_tempo)
        tt = tick2second(total_time_t, ticks_per_beat, midi_tempo)
        o = tick2second(offset, ticks_per_beat, midi_tempo)
        print(
            f"\tmodified '{os.path.basename(midi_file_path)}' track {i} to have end time {net:.03f}: {tt:.03f} -> offset is {o:.03f}"
        )

    # save() overwrites the existing file; deleting it first would lose the
    # file altogether if saving failed
    mid.save(midi_file_path)

In [None]:
def segment_midi(
    midi_file_path: str,
    output_dir: str,
    num_beats: int = 8,
    do_shift: bool = False,
    do_transpose: bool = False,
) -> List[str]:
    """Cut a MIDI file into consecutive segments of ``num_beats`` beats.

    The tempo is read from the second "-"-separated field of the file name
    and written into the source file before it is loaded. Only the first
    instrument of the source is used. Every segment is written to
    ``output_dir`` as its own MIDI file, then has its tempo and
    end-of-track position set.

    Args:
        midi_file_path: source MIDI file; its name must look like
            ``<prefix>-<bpm>-<rest>.mid``.
        output_dir: folder that receives the segment files.
        num_beats: length of one segment in beats.
        do_shift: accepted but not used by this function.
        do_transpose: accepted but not used by this function.

    Returns:
        Paths of the segment files, in segment order.
    """
    filename = os.path.basename(midi_file_path)[:-4]
    target_tempo = int(filename.split("-")[1])
    set_tempo(midi_file_path, target_tempo)

    midi_pm = pretty_midi.PrettyMIDI(midi_file_path)
    total_length = midi_pm.get_end_time()
    segment_length = num_beats * 60 / target_tempo  # seconds per segment
    num_segments = int(np.round(total_length / segment_length))
    eighth_beat = segment_length / num_beats / 8  # one eighth of a beat, in seconds

    print(
        f"\tbreaking '{filename}' ({total_length:.03f} s at {target_tempo} bpm) into {num_segments:03d} segments of {segment_length:.03f}s\n\t(pre window is {eighth_beat:.03f} s)"
    )

    new_files = []
    for n in range(num_segments):
        has_pre_window = n > 0

        # each window closes an eighth of a beat early; all windows but the
        # first also open an eighth of a beat early
        start = n * segment_length
        end = start + segment_length - eighth_beat
        if has_pre_window:
            start -= eighth_beat
        print(f"\t{n:03d} splitting from {start:08.03f} s to {end:07.03f} s)")

        span_label = f"{int(start):04d}-{int(end):04d}"
        segment_pm = pretty_midi.PrettyMIDI(initial_tempo=target_tempo)
        source = midi_pm.instruments[0]
        instrument = pretty_midi.Instrument(
            program=source.program,
            name=f"{filename}_{span_label}",
        )

        # copy the notes that begin inside the window, re-based to the window
        # start; note ends are not clipped to the window
        for note in source.notes:
            if not (start <= note.start < end):
                continue
            rel_start = note.start - start
            rel_end = note.end - start
            if has_pre_window:
                # pad front of track to full bar for easier playback
                rel_start += eighth_beat * 7
                rel_end += eighth_beat * 7
            instrument.notes.append(
                pretty_midi.Note(
                    velocity=note.velocity,
                    pitch=note.pitch,
                    start=rel_start,
                    end=rel_end,
                )
            )

        segment_pm.instruments.append(instrument)

        # write out, then normalise tempo and end-of-track of the new file
        segment_filename = os.path.join(output_dir, f"{filename}_{span_label}_n00.mid")
        segment_pm.write(segment_filename)
        set_tempo(segment_filename, target_tempo)
        modify_end_of_track(segment_filename, segment_length, target_tempo)

        new_files.append(segment_filename)

    return new_files

In [None]:
def build_fs(dirs: List[str]) -> None:
    """Make sure every folder in ``dirs`` exists and holds no files.

    A folder that already exists has its files removed; sub-directories are
    left in place. A folder that does not exist is created, together with
    any missing parent folders.

    Args:
        dirs: folder paths to clean or create.
    """
    for folder in dirs:
        if os.path.exists(folder):
            removed = 0
            for entry in os.listdir(folder):
                entry_path = os.path.join(folder, entry)
                # os.remove cannot delete a directory; skip real directories
                # (a symlink to one is still removable)
                if os.path.isdir(entry_path) and not os.path.islink(entry_path):
                    continue
                os.remove(entry_path)
                removed += 1
            print(f"cleaned {removed} files out of folder: '{folder}'")
        else:
            # makedirs also creates missing parents, e.g. "files/play"
            os.makedirs(folder)
            print(f"created new folder: '{folder}'")

## go

In [None]:
# MIDI files to segment, looked up in the "files" folder by the loop below;
# the tempo is parsed from the second "-"-separated field of each name
tracks = ["from-60-db-test.mid"]

In [None]:
# split every listed MIDI file into segments and collect the new file paths
segment_paths = []
for trackname in track(tracks, description="generating segments"):
    # anything that is not a .mid file is ignored
    if not trackname.endswith(".mid"):
        continue
    print(f"segmenting '{trackname}'")
    segment_paths += segment_midi(
        os.path.join("files", trackname),
        os.path.join("files", "play"),
    )

print(f"[green]segmentation complete, {len(segment_paths)} files generated")

In [None]:
# print the tracks of the fourth generated segment for inspection
# NOTE(review): index 3 assumes at least four segments were generated
MidiFile(segment_paths[3]).print_tracks()
# NOTE(review): MIDIPlayer is not imported in any visible cell — confirm
# where it is supposed to come from before running this cell
MIDIPlayer(segment_paths[-1], 300)

## just augment

In [None]:
from rich.progress import (
    Progress,
    SpinnerColumn,
    TimeElapsedColumn,
    MofNCompleteColumn,
)
from itertools import product
from pretty_midi import PrettyMIDI, Note, Instrument
import os
from pathlib import Path

from typing import Dict

def augment_midi(trackname: str, files: List[str], output_path: str) -> List[str]:
    """Write every transpose/shift variant of each file in ``files``.

    Each file is passed to ``transform`` once per combination of a transpose
    value in ``range(12)`` and a shift value in ``range(8)``. Progress is
    shown with a rich progress bar.

    Args:
        trackname: name of the source track; the tempo is read from its
            second "-"-separated field.
        files: paths of the MIDI files to augment.
        output_path: folder that receives the augmented files.

    Returns:
        Paths of all files written, in generation order.
    """
    # built once: the combinations are the same for every file, and
    # transform only reads from these dicts
    transformations = [
        {"transpose": t, "shift": s} for t, s in product(range(12), range(8))
    ]
    augmented_files = []

    p = Progress(
        SpinnerColumn(),
        *Progress.get_default_columns(),
        TimeElapsedColumn(),
        MofNCompleteColumn(),
        refresh_per_second=1,
    )
    # derive the total from the list so it cannot drift from the ranges above
    task_a = p.add_task(
        f"augmenting {trackname}", total=len(files) * len(transformations)
    )

    with p:
        for segment_filename in files:
            tempo = int(trackname.split("-")[1])
            for transformation in transformations:
                augmented_files.append(
                    transform(segment_filename, output_path, tempo, transformation)
                )
                p.update(task_a, advance=1)

    return augmented_files

def transform(file_path: str, out_dir: str, tempo: int, transformations: Dict, num_beats: int = 8) -> str:
    """Write a transposed and/or time-shifted copy of a MIDI file.

    The output is named ``<stem>_t<transpose>s<shift>.mid`` inside
    ``out_dir``. The source is copied there first, so a file exists even
    when neither transformation applies; the transpose and shift steps then
    each re-read and overwrite that file. Finally the tempo is set.

    Args:
        file_path: MIDI file to transform.
        out_dir: folder that receives the output file.
        tempo: tempo in beats per minute.
        transformations: dict with integer values under ``"transpose"``
            (added to every pitch) and ``"shift"`` (multiplied by the length
            of one beat to get the time shift).
        num_beats: beats per segment; the shift wraps at ``num_beats + 1``
            beats.

    Returns:
        Path of the file that was written.
    """
    # read the values up front: reusing double quotes inside the f-string
    # below is a syntax error on Python versions before 3.12
    transpose = transformations["transpose"]
    shift = transformations["shift"]

    new_filename = f"{Path(file_path).stem}_t{transpose:02d}s{shift:02d}.mid"
    instrument_name = new_filename[:-4]
    out_path = os.path.join(out_dir, new_filename)
    MidiFile(file_path).save(out_path)  # in case transpose is 0

    if transpose != 0:
        t_midi = PrettyMIDI(initial_tempo=tempo)

        for instrument in PrettyMIDI(out_path).instruments:
            transposed_instrument = Instrument(program=instrument.program, name=instrument_name)

            # NOTE(review): nothing here keeps the transposed pitch within
            # the MIDI range 0-127 — confirm the inputs leave enough headroom
            for note in instrument.notes:
                transposed_instrument.notes.append(
                    Note(
                        velocity=note.velocity,
                        pitch=note.pitch + int(transpose),
                        start=note.start,
                        end=note.end,
                    )
                )

            t_midi.instruments.append(transposed_instrument)

        t_midi.write(out_path)

    if shift != 0:
        s_midi = PrettyMIDI(initial_tempo=tempo)
        seconds_per_beat = 60 / tempo
        shift_seconds = shift * seconds_per_beat
        loop_point = (num_beats + 1) * seconds_per_beat

        for instrument in PrettyMIDI(out_path).instruments:
            shifted_instrument = Instrument(
                program=instrument.program, name=instrument_name
            )
            for note in instrument.notes:
                dur = note.end - note.start
                # starts wrap around at loop_point; durations are kept
                shifted_start = (note.start + shift_seconds) % loop_point
                shifted_end = shifted_start + dur

                # notes that wrapped around are pushed one beat later
                # (presumably to skip the lead-in beat — TODO confirm)
                if note.start + shift_seconds >= loop_point:
                    shifted_start += seconds_per_beat
                    shifted_end += seconds_per_beat

                shifted_instrument.notes.append(
                    Note(
                        velocity=note.velocity,
                        pitch=note.pitch,
                        start=shifted_start,
                        end=shifted_end,
                    )
                )

            s_midi.instruments.append(shifted_instrument)

        s_midi.write(out_path)

    change_tempo(out_path, tempo)

    return out_path


def get_tempo(filename: str) -> int:
    """Return the tempo stored as the second "-"-separated field of ``filename``.

    Raises IndexError when the name contains no "-" and ValueError when the
    field is not an integer.
    """
    # fields after the second one are irrelevant, so stop splitting there
    tempo_field = filename.split("-", 2)[1]
    return int(tempo_field)


def change_tempo(file_path: str, tempo: int):
    """Replace every tempo event in a MIDI file with a single new one.

    All ``set_tempo`` events are removed from all tracks and one event for
    ``tempo`` is inserted at the start of the first track. A file without
    tracks gets a new track holding just that event. The file is overwritten
    in place.

    Args:
        file_path: path of the MIDI file to modify.
        tempo: new tempo in beats per minute.
    """
    midi = mido.MidiFile(file_path)
    new_message = mido.MetaMessage("set_tempo", tempo=mido.bpm2tempo(tempo), time=0)

    for midi_track in midi.tracks:
        # rebuild the track rather than removing while iterating, which
        # skips the element after each removal and so leaves back-to-back
        # set_tempo events behind
        kept = []
        carried_ticks = 0
        for msg in midi_track:
            if msg.type == "set_tempo":
                # hand the removed event's delta to the next message so the
                # remaining events stay on the same tick positions
                carried_ticks += msg.time
                continue
            if carried_ticks:
                msg = msg.copy(time=msg.time + carried_ticks)
                carried_ticks = 0
            kept.append(msg)
        # slice assignment keeps the original MidiTrack object
        midi_track[:] = kept

    if midi.tracks:
        midi.tracks[0].insert(0, new_message)
    else:
        # no track to put the event in: add one that holds only the tempo
        new_track = mido.MidiTrack()
        new_track.append(new_message)
        midi.tracks.append(new_track)

    midi.save(file_path)

# augment every segment in the "play" folder, writing variants to "train"
augmented_files = []
folder = os.path.join('..', 'data', 'datasets', 'test', 'play')
# os.listdir returns entries in arbitrary order; sort for reproducible runs
for filename in sorted(os.listdir(folder)):
    # skip anything that is not a MIDI file (e.g. hidden OS files), which
    # would otherwise fail when its tempo is parsed from the name
    if not filename.endswith(".mid"):
        continue
    augmented_files.extend(
        augment_midi(
            filename[:-4],
            [os.path.join(folder, filename)],
            os.path.join('..', 'data', 'datasets', 'test', 'train'),
        )
    )

print(f"[green bold]augmentation complete, {len(augmented_files)} files generated")