In [4]:
import os
import sys
import mido
import time
import numpy as np

project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from src.utils import basename
from src.ml.specdiff.model import SpectrogramDiffusion, DEFAULT_CONFIG

In [5]:
tmp_dir = "outputs/embeddings"
midi_port = "LPK25"
DEFAULT_CONFIG["device"] = "cpu"
DEFAULT_CONFIG["encoder_weights_path"] = (
    "/Users/finlay/Documents/Programming/disklavier/src/ml/specdiff/note_encoder.bin"
)
model = SpectrogramDiffusion(DEFAULT_CONFIG, verbose=False)
record_length = 5
difference_threshold = 2.5

In [None]:
master_midi = mido.MidiFile(ticks_per_beat=220)
master_midi.add_track()

midi = mido.MidiFile(ticks_per_beat=220)
track = mido.MidiTrack()
track.name = "player_recording_00"

first_msg = True
num_embeddings = 0
last_embedding = None
embeddings = np.zeros((10, 768))  # only store last 10 embeddings
print("waiting for first note")
try:
    with mido.open_input(midi_port) as inport:
        for msg in inport:
            # init time tracking
            if first_msg:
                first_msg = False
                start_time = time.time()
                last_note_time = start_time
                now = start_time
                print("starting recording at", start_time)
            else:
                now = time.time()

            # update msg time
            msg.time = mido.second2tick(now - last_note_time, 220, mido.bpm2tempo(120))
            last_note_time = now
            track.append(msg)
            master_midi.tracks[0].append(msg)

            # save and embed when recording is done
            if now - start_time > record_length:
                # print(track)
                midi.tracks.append(track)
                filename = os.path.join(
                    tmp_dir, f"player_recording_{num_embeddings:02d}.mid"
                )
                midi.save(filename)

                print(f"total time is {now - start_time}")
                print(f"embedding {num_embeddings}")
                embedding = model.embed(filename)
                embeddings[num_embeddings % 10] = embedding
                print("embedding complete")

                if last_embedding is not None:
                    # calculate magnitude of difference between last embedding and current embedding
                    diff_mag = np.linalg.norm(last_embedding - embedding)
                    print(f"Difference magnitude: {diff_mag}")

                # reset tracking
                midi = mido.MidiFile(ticks_per_beat=220)
                track = mido.MidiTrack()
                track.name = f"player_recording_{num_embeddings}"
                last_embedding = embedding
                start_time = time.time()
                num_embeddings += 1
except Exception as e:
    print(e)


waiting for first note
starting recording at 1745949156.256033
total time is 5.839620113372803
embedding 0


  embeddings[num_embeddings % 10] = embedding


embedding complete
total time is 5.233791828155518
embedding 1
embedding complete
Difference magnitude: 2.4894564151763916
total time is 5.038543701171875
embedding 2
embedding complete
Difference magnitude: 2.8305609226226807
total time is 5.088855028152466
embedding 3
embedding complete
Difference magnitude: 2.5169053077697754
total time is 6.6554319858551025
embedding 4
embedding complete
Difference magnitude: 2.1539387702941895


KeyboardInterrupt: 

In [8]:
master_midi.save(os.path.join(tmp_dir, "master_midi.mid"))
master_midi.print_tracks()

=== Track 0
Message('note_on', channel=0, note=40, velocity=89, time=0)
Message('note_off', channel=0, note=40, velocity=127, time=100)
Message('note_on', channel=0, note=41, velocity=105, time=42)
Message('note_off', channel=0, note=41, velocity=127, time=95)
Message('note_on', channel=0, note=44, velocity=119, time=45)
Message('note_off', channel=0, note=44, velocity=127, time=86)
Message('note_on', channel=0, note=46, velocity=109, time=70)
Message('note_off', channel=0, note=46, velocity=127, time=99)
Message('note_on', channel=0, note=48, velocity=113, time=51)
Message('note_off', channel=0, note=48, velocity=127, time=109)
Message('note_on', channel=0, note=44, velocity=109, time=27)
Message('note_off', channel=0, note=44, velocity=127, time=79)
Message('note_on', channel=0, note=48, velocity=123, time=75)
Message('note_off', channel=0, note=48, velocity=127, time=292)
Message('note_on', channel=0, note=47, velocity=117, time=29)
Message('note_on', channel=0, note=44, velocity=1,