In [None]:
# install dependencies
!pip install -q --upgrade mido

[?25l[K     |██████▍                         | 10 kB 22.1 MB/s eta 0:00:01[K     |████████████▉                   | 20 kB 11.8 MB/s eta 0:00:01[K     |███████████████████▎            | 30 kB 9.6 MB/s eta 0:00:01[K     |█████████████████████████▋      | 40 kB 8.7 MB/s eta 0:00:01[K     |████████████████████████████████| 51 kB 3.3 MB/s 
[?25h

In [None]:
from mido import MidiFile, MidiTrack, Message, MetaMessage
from enum import IntEnum
from collections import deque

In [None]:
def tics_to_millis(tempo: int, ticks_per_beat: int, ticks: int) -> int:
  """Convert a MIDI tick count into a rounded number of 10 ms steps.

  Despite the name, the result is not in milliseconds: ``tempo * ticks /
  ticks_per_beat`` is divided by 10000, so with ``tempo`` given in
  microseconds per beat (the unit of a MIDI ``set_tempo`` event) one unit
  of the result is 10 ms.

  Args:
    tempo: Tempo value, microseconds per beat.
    ticks_per_beat: Time resolution of the MIDI file.
    ticks: Delta time to convert, in ticks.

  Returns:
    The duration rounded to the nearest whole step (Python's ``round``,
    so exact halves go to the even neighbour).
  """
  # Keep a single division so the floating point result is not perturbed.
  scaled_ticks = tempo * ticks
  ticks_per_step = ticks_per_beat * 10000
  return int(round(scaled_ticks / ticks_per_step))

class SpecialTokens(IntEnum):
  """Reserved token ids at the bottom of the vocabulary.

  Every other token id in this file is offset by ``len(SpecialTokens)``.
  """
  PADDING = 0  # filler token; ignored when decoding back to MIDI
  START = 1  # first token of every sequence
  END = 2  # last token of every sequence

class Serializer:
  """Translate MIDI files into integer token sequences and back again.

  Encoding happens in two stages:

  1. A "pretokenized" stream with one token per event:
     ``len(SpecialTokens) + pitch`` for a note-on,
     ``len(SpecialTokens) + NOTE_RANGE + pitch`` for a note-off and
     ``len(SpecialTokens) + 2 * NOTE_RANGE + p`` for a pause of ``2**p``
     time units.
  2. The final stream, in which a pause of ``2**p`` units becomes
     ``len(SpecialTokens) + p`` and every distinct set of simultaneously
     held notes (a "chord", stored as an int bitmask with bit ``pitch``
     set) gets its own id above the pause range.

  Time units are whatever ``tics_to_millis`` returns; ``deserialize``
  writes each unit as 10 ticks at 500 ticks per beat and a tempo of
  500000 microseconds per beat, i.e. as 10 ms.

  NOTE: chord ids are allocated directly above the pause range implied by
  ``max_power`` at allocation time, and ``encode_spaces`` raises
  ``max_power`` whenever it meets a longer pause. If that happens after
  chords have been allocated, the enlarged pause range overlaps the
  earliest chord ids. Pass a ``max_power`` that is large enough for the
  whole corpus to keep the ids unambiguous.
  """

  # Tempo assumed when track 0 carries no set_tempo event (120 bpm).
  DEFAULT_TEMPO = 500000

  def __init__(self, max_power=0) -> None:
    """Create an empty serializer.

    Args:
      max_power: Largest pause exponent reserved up front; pauses of up
        to ``2**(max_power + 1) - 1`` units fit without growing the range.
    """
    self.max_power = max_power
    self.NOTE_RANGE = 128
    self.SONG_SPLIT_THRESH = 300 # split songs if >= 3s space
    self.SONG_SKIP_THRESH = 1000 # skip song if <= 10s length
    self.time_shifts = []  # every pause ever passed to encode_spaces
    self.chord_to_token = {}
    self.token_to_chord = {}

  def encode_spaces(self, time_shift: int) -> list:
    """Encode a pause as pretokenized power-of-two tokens.

    The pause is recorded in ``self.time_shifts`` and ``self.max_power``
    is raised to the highest exponent used.

    Args:
      time_shift: Length of the pause in time units.

    Returns:
      One token per set bit of ``time_shift``, highest exponent first.
      Empty for a pause shorter than one unit.
    """
    self.time_shifts.append(time_shift)
    spaces = deque()
    start = self.__space_start()
    mask = 1
    power = 0

    while time_shift >= mask:
      if time_shift & mask:
        self.max_power = max(power, self.max_power)
        # Prepend so the largest power ends up first.
        spaces.appendleft(start + power)

      mask <<= 1
      power += 1

    return list(spaces)

  def tokenize(self, midi: MidiFile, do_not_skip=False) -> list:
    """Tokenize a MIDI file into one final token list per song.

    Args:
      midi: File to encode; tempo is read from track 0, notes from track 1.
      do_not_skip: Keep songs even when they are not longer than
        ``SONG_SKIP_THRESH``.

    Returns:
      A list of token lists, each starting with START and ending with END.
    """
    pretokenized = self.__pretokenize(midi, do_not_skip)

    return [self.__tokenize_single(song) for song in pretokenized]

  def __tokenize_chord(self, pushed_notes):
    """Return the token id of a chord bitmask, allocating one if needed."""
    if pushed_notes in self.chord_to_token:
      return self.chord_to_token[pushed_notes]

    # The next free id is exactly the current vocabulary size.
    token = self.vocab_size()
    self.chord_to_token[pushed_notes] = token
    self.token_to_chord[token] = pushed_notes
    return token

  def __tokenize_single(self, tokens: list) -> list:
    """Turn one pretokenized song into the final chord/pause stream.

    Note events are folded into a bitmask of held pitches; a chord token
    is emitted whenever that bitmask changes.
    """
    space_start = self.__space_start()
    pushed_notes = 0
    prev_pushed_notes = 0
    tokenization = [SpecialTokens.START.value]

    for token in tokens:
      if token == SpecialTokens.START.value or token == SpecialTokens.END.value:
        continue

      if token >= space_start:
        # Re-base the pause token onto the final vocabulary layout.
        power = token - space_start
        tokenization.append(len(SpecialTokens) + power)
        continue

      token -= len(SpecialTokens)
      on = token < self.NOTE_RANGE
      pitch = token % self.NOTE_RANGE
      powered = (1 << pitch)
      if on:
        pushed_notes = pushed_notes | powered # activate note
      elif pushed_notes & powered:
        pushed_notes -= powered # deactivate note

      if pushed_notes != prev_pushed_notes:
        tokenization.append(self.__tokenize_chord(pushed_notes))
        prev_pushed_notes = pushed_notes

    tokenization.append(SpecialTokens.END.value)
    return tokenization

  def __initial_tempo(self, midi):
    """Return the tempo of the first set_tempo event in track 0.

    Falls back to ``DEFAULT_TEMPO`` when track 0 has no such event instead
    of failing on whatever message happens to come first.
    """
    for message in midi.tracks[0]:
      if message.type == "set_tempo":
        return message.tempo
    return self.DEFAULT_TEMPO

  def __pretokenize(self, midi: MidiFile, do_not_skip=False) -> list:
    """Convert track 1 of ``midi`` into pretokenized songs.

    A pause of at least ``SONG_SPLIT_THRESH`` units while no note is held
    starts a new song. Songs whose accumulated pauses do not exceed
    ``SONG_SKIP_THRESH`` are dropped unless ``do_not_skip`` is set.
    Leading and trailing pauses of every song are removed.
    """
    pressed = set()  # pitches that are currently held down

    out = [[SpecialTokens.START.value]]

    ppq = midi.ticks_per_beat
    tempo = self.__initial_tempo(midi)
    time_shift = 0
    time_shift_global = 0
    for message in midi.tracks[1]:
      if message.type == "track_name" or message.type == "program_change":
        continue

      time_shift += tics_to_millis(tempo, ppq, message.time)
      if message.type == "control_change":
        continue

      # skip initial pause
      if out[-1][-1] == SpecialTokens.START.value:
        time_shift = 0

      # split songs, but only while no note is pressed
      if time_shift >= self.SONG_SPLIT_THRESH and not pressed:
        if time_shift_global <= self.SONG_SKIP_THRESH and not do_not_skip:
          # discard previous song if <= 10s
          out[-1] = [SpecialTokens.START.value]
        else:
          out[-1].append(SpecialTokens.END.value)
          out.append([SpecialTokens.START.value])
        time_shift = 0
        time_shift_global = 0

      # encode spaces
      if time_shift > 0:
        out[-1].extend(self.encode_spaces(time_shift))
        time_shift_global += time_shift
        time_shift = 0

      # encode note; a release is either a note_off message or a note_on
      # with velocity 0
      if message.type == "note_on" or message.type == "note_off":
        released = message.type == "note_off" or message.velocity == 0
        val = len(SpecialTokens) + message.note
        if released:
          val += self.NOTE_RANGE
          pressed.discard(message.note)
        else:
          pressed.add(message.note)
        out[-1].append(val)
    out[-1].append(SpecialTokens.END.value)

    # remove last song if too short
    if time_shift_global <= self.SONG_SKIP_THRESH and not do_not_skip:
      out = out[:-1]

    space_start = self.__space_start()
    clean = []
    # trim final space
    for tokenized in out:
      # Start at the token before END and walk left over trailing pauses.
      index = -2
      while tokenized[index] >= space_start:
        index -= 1
      clean.append(tokenized[:index + 1] + [SpecialTokens.END.value])
    return clean

  def __space_start(self):
    """First pause token id in the pretokenized layout."""
    return len(SpecialTokens) + self.NOTE_RANGE * 2

  def __space_start_decode(self):
    """First pause token id in the final layout."""
    return len(SpecialTokens)

  def __space_end_decode(self):
    """One past the last pause token id in the final layout."""
    return self.__space_start_decode() + self.max_power + 1

  def vocab_size(self) -> int:
    """Return the number of token ids in use: special, pause and chord."""
    return len(SpecialTokens) + self.max_power + 1 + len(self.chord_to_token)

  def deserialize(self, tokenized: list) -> MidiFile:
    """Rebuild a two-track MIDI file from a final token sequence.

    Track 0 holds the tempo, track 1 the notes. Pause tokens accumulate
    and are written as the delta time (10 ticks per unit) of the next
    emitted message. A chord token that equals the chord already held
    changes nothing: it is skipped and the pending pause carries over.

    Args:
      tokenized: Final tokens as produced by ``tokenize``; PADDING is
        ignored.

    Returns:
      The reconstructed ``MidiFile``.

    Raises:
      KeyError: If a token is neither special, a pause nor a known chord.
    """
    mid = MidiFile(ticks_per_beat=500)
    ctrl_track = MidiTrack()
    ctrl_track.append(MetaMessage('set_tempo', tempo=500000, time=0))
    ctrl_track.append(MetaMessage('end_of_track', time=1))
    mid.tracks.append(ctrl_track)

    space_start = self.__space_start_decode()
    space_end = self.__space_end_decode()
    notes_pressed = 0
    track = MidiTrack()
    time_shift = 0
    for token in tokenized:
      # skip padding
      if token == SpecialTokens.PADDING.value:
        continue

      # space
      if token >= space_start and token < space_end:
        time_shift += 1 << (token - space_start)
        continue

      if token == SpecialTokens.START.value:
        messages = [Message('program_change', channel=0, program=0)]
      elif token == SpecialTokens.END.value:
        messages = [MetaMessage('end_of_track')]
      else:
        # note: emit one message per pitch whose state differs
        chord = self.token_to_chord[token]
        diff = chord ^ notes_pressed
        messages = []
        pitch = 1
        pitch_idx = 0

        while pitch <= diff:
          if pitch & diff:
            # velocity 0 releases a pitch that is no longer in the chord
            velocity = 64 if (pitch & chord) else 0
            messages.append(Message('note_on', channel=0, note=pitch_idx, velocity=velocity, time=0))
          pitch <<= 1
          pitch_idx += 1
        notes_pressed = chord

      if not messages:
        # Nothing to write; keep time_shift so the pause is not lost.
        continue

      # Only the first message of a group carries the elapsed time.
      messages[0].time = time_shift * 10
      track.extend(messages)
      time_shift = 0
    mid.tracks.append(track)
    return mid