# **Code setup**

In [None]:
! pip install pretty_midi

Collecting pretty_midi
[?25l  Downloading https://files.pythonhosted.org/packages/bc/8e/63c6e39a7a64623a9cd6aec530070c70827f6f8f40deec938f323d7b1e15/pretty_midi-0.2.9.tar.gz (5.6MB)
[K     |████████████████████████████████| 5.6MB 6.8MB/s 
Collecting mido>=1.1.16
[?25l  Downloading https://files.pythonhosted.org/packages/20/0a/81beb587b1ae832ea6a1901dc7c6faa380e8dd154e0a862f0a9f3d2afab9/mido-1.2.9-py2.py3-none-any.whl (52kB)
[K     |████████████████████████████████| 61kB 8.7MB/s 
Building wheels for collected packages: pretty-midi
  Building wheel for pretty-midi (setup.py) ... [?25l[?25hdone
  Created wheel for pretty-midi: filename=pretty_midi-0.2.9-cp37-none-any.whl size=5591954 sha256=01eb090a608135ed7800a73bb4d7292722dc5e522f20e57307716067cf3c7b48
  Stored in directory: /root/.cache/pip/wheels/4c/a1/c6/b5697841db1112c6e5866d75a6b6bf1bef73b874782556ba66
Successfully built pretty-midi
Installing collected packages: mido, pretty-midi
Successfully installed mido-1.2.9 pretty-midi

In [None]:
# imports 
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
from random import shuffle
from pretty_midi import Note, PrettyMIDI, Instrument, ControlChange
import six
import copy, pathlib
import os, time, datetime, random, copy
import argparse
import numpy as np
import math

In [None]:
# helper functions

def vectorize(sequence):
    """
    Pack a sequence of note objects into a numpy array with one row per
    note and the four columns (start, end, pitch, velocity).
    """
    rows = []
    for note in sequence:
        rows.append([note.start, note.end, note.pitch, note.velocity])
    return np.asarray(rows)

def devectorize(note_array):
    """
    Converts a vectorized note sequence back into a list of pretty_midi Note
    objects.

    Args:
        note_array (np.ndarray): one row per note with the columns
            (start, end, pitch, velocity), as produced by vectorize.
    Returns:
        list of Note objects, one per row of note_array.
    """
    #vectorize stores all four fields in a single array, so whenever the
    #timings are fractional the pitch and velocity come back as floats;
    #cast them to the integer values MIDI pitch/velocity are meant to be
    return [Note(start=start, end=end, pitch=int(pitch), velocity=int(velocity))
            for start, end, pitch, velocity in note_array.tolist()]


def one_hot(sequence, n_states):
    """
    One-hot encode a list of integer states.

    Rows of an (n_states x n_states) identity matrix are selected by the
    entries of sequence, giving a tensor of shape (len(sequence), n_states).
    The result is moved to the GPU whenever CUDA is available.
    """
    identity = torch.eye(n_states)
    encoded = identity[sequence, :]
    return encoded.cuda() if torch.cuda.is_available() else encoded

def decode_one_hot(vector):
    '''
    Return the position of the single non-zero entry of a one-hot vector.
    '''
    hot_positions = torch.nonzero(vector)
    #.item() only succeeds when exactly one entry is non-zero
    return hot_positions.item()

def prepare_batches(sequences, batch_size):
    """
    Generator over fixed-size batches of (input, target) sequence pairs.

    The sequences are consumed batch_size at a time. Within a batch they are
    ordered longest first, and every sequence contributes an input (all but
    its last element) and a target (all but its first element), so the target
    is the input shifted one step ahead. E.g. "to be or not to be" yields the
    input "to be or not to b" and the target "o be or not to be".
    """
    for start in range(0, len(sequences), batch_size):
        #longest-first ordering is needed for packing batches to work
        ordered = sorted(sequences[start:start + batch_size],
                key=len, reverse=True)
        inputs = [seq[:-1] for seq in ordered]
        targets = [seq[1:] for seq in ordered]
        yield inputs, targets

def clones(module, N):
    "Return a ModuleList holding N independent deep copies of module"
    copies = []
    for _ in range(N):
        copies.append(copy.deepcopy(module))
    return torch.nn.ModuleList(copies)

def d(tensor=None):
    """
    Name the device to use: 'cuda' or 'cpu'.

    With a tensor, report where that tensor lives; without one, report
    whether CUDA is available at all.
    """
    if tensor is not None:
        on_gpu = tensor.is_cuda
    else:
        on_gpu = torch.cuda.is_available()
    return 'cuda' if on_gpu else 'cpu'

def write_midi(note_sequence, output_dir, filename):
    """
    Write a note sequence to "<output_dir>/<filename>.midi" as a single
    piano track.

    Args:
        note_sequence (list): pretty_midi Note objects to place on the track.
        output_dir (str): directory to write into; created (with parents)
            if it does not exist yet.
        filename (str): file name without extension; also used as the
            track name.
    """
    #make output directory
    target_dir = pathlib.Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    #generate midi
    midi = PrettyMIDI()
    piano_track = Instrument(program=0, is_drum=False, name=filename)
    piano_track.notes = note_sequence
    midi.instruments.append(piano_track)
    #join with a path separator so the file lands inside output_dir even
    #when output_dir is given without a trailing slash
    output_name = target_dir / f"{filename}.midi"
    midi.write(str(output_name))

def sample(model, sample_length, prime_sequence=None, temperature=1):
    """
    Generate a MIDI event sequence of a fixed length by randomly sampling from a model's distribution of sequences. Optionally, "seed" the sequence with a prime. A well-trained model will create music that responds to the prime and develops upon it.

    Args:
        model: callable mapping a (1, t) LongTensor to scores indexed as
            [batch, position, token]; must expose eval() and n_tokens.
        sample_length (int): number of events to generate.
        prime_sequence (sequence of int or None): events to start from. When
            empty or None, a random starting event is drawn instead. The
            argument itself is never modified.
        temperature (float): divisor applied to the scores before softmax.
    Returns:
        list of int: the prime (or random start) followed by the
        sample_length generated events.
    """
    #deactivate training mode
    model.eval()
    if prime_sequence is None or len(prime_sequence) == 0:
        #if no prime is provided, randomly select a starting event
        input_sequence = [np.random.randint(model.n_tokens)]
    else:
        #list() copies and also accepts tuples/arrays, unlike .copy()
        input_sequence = list(prime_sequence)

    #run on whatever device the model's weights live on
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device('cpu')

    #add singleton dimension for the batch
    input_tensor = torch.LongTensor(input_sequence).unsqueeze(0).to(device)

    #generation needs no gradients; skip building the autograd graph
    with torch.no_grad():
        for _ in range(sample_length):
            #select probabilities of *next* token
            out = model(input_tensor)[0, -1, :]
            #out is a 1d tensor of shape (n_tokens)
            probs = F.softmax(out / temperature, dim=0)
            #sample prob distribution for next character
            c = torch.multinomial(probs, 1)
            #slide the window: drop the oldest event, append the new one
            input_tensor = torch.cat([input_tensor[:, 1:], c[None]], dim=1)
            input_sequence.append(c.item())

    return input_sequence

In [None]:
# custom help code for training
class Accuracy(nn.Module):
    """
    Fraction of positions at which the highest-scoring class equals the
    target, optionally discounting padded positions in the denominator.
    """
    def __init__(self):
        super().__init__()

    def forward(self, prediction, target, mask=None, token_dim=-1,
            sequence_dim=-2):
        """
        Args:
            prediction (Tensor): class scores.
            target (Tensor): integer class ids, compared elementwise with
                the argmax of prediction.
            mask (Tensor or None): when given, the number of entries equal
                to 0 is subtracted from the denominator.
            token_dim (int): unused; kept so existing callers keep working.
            sequence_dim (int): axis of prediction that is reduced by the
                argmax (default -2).
        Returns:
            0-d tensor: number of matches / (number of positions - padded).
        """
        #The original took a softmax along token_dim and then an argmax along
        #sequence_dim. Normalizing along a different axis than the one being
        #reduced can change which entry wins, so the scores are compared
        #directly; along a single axis softmax never changes the argmax.
        prediction = prediction.argmax(sequence_dim)

        scores = (prediction == target)
        n_padded = 0
        if mask is not None:
            n_padded = (mask == 0).sum()
        #NOTE(review): matches at padded positions still count in the
        #numerator -- confirm whether targets can match there
        return scores.sum() / float(scores.numel() - n_padded)


def smooth_cross_entropy(prediction, target, eps=0.1,
        ignore_index=0):
    """
    Cross entropy against label-smoothed targets, averaged over the
    positions whose target differs from ignore_index.

    Each one-hot target is blended with a uniform distribution over the
    classes: (1 - eps) * one_hot + eps / n_classes. Positions whose target
    equals ignore_index contribute nothing to the sum or to the count.
    """
    #swap axes 1 and 2 so the class scores sit on the last axis
    logits = prediction.transpose(1, 2)
    n_classes = logits.shape[-1]
    log_probs = F.log_softmax(logits, -1)

    #smoothed target distribution
    hard_targets = F.one_hot(target, n_classes)
    uniform = 1.0 / n_classes
    soft_targets = (1.0 - eps) * hard_targets + eps * uniform

    #zero out every class probability at padding positions
    is_padding = (target == ignore_index)
    soft_targets = soft_targets.masked_fill(is_padding.unsqueeze(-1), 0)

    total = -torch.sum(soft_targets * log_probs)
    #mean reduction over the non-padding positions
    n_items = torch.sum(target != ignore_index)
    return total / n_items


class TFSchedule:
    """
    From https://www.tensorflow.org/tutorials/text/transformer. Wrapper for Optimizer, gradually increases learning rate for a warmup period before learning rate decay sets in.

    The rate at a given step is
        d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
    """

    def __init__(self, optimizer, d_model, warmup_steps=4000):
        """
        Args:
            optimizer: object exposing param_groups (a list of dicts with an
                'lr' entry) and step().
            d_model (int): model dimension used to scale the rate.
            warmup_steps (int): number of steps over which the rate grows.
        """
        self.opt = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps

        self._step = 0
        self._rate = 0

    def step(self):
        """Advance one step, push the new rate into every parameter group
        and step the wrapped optimizer."""
        self._step += 1
        rate = self.rate()
        for p in self.opt.param_groups:
            p['lr'] = rate

        self._rate = rate
        self.opt.step()

    def rate(self, step=None):
        """
        Learning rate at the given step (defaults to the current step).

        Returns 0.0 for step <= 0: the warmup term is zero there, and
        evaluating step ** -0.5 at zero would raise ZeroDivisionError.
        """
        if step is None:
            step = self._step

        if step <= 0:
            return 0.0

        arg1 = step ** (-0.5)
        arg2 = step * (self.warmup_steps ** -1.5)

        return self.d_model ** (-0.5) * min(arg1, arg2)

In [None]:
# sequence encoding code
class SequenceEncoderError(Exception):
    """Raised when an event or event number cannot be encoded or decoded."""

class SequenceEncoder():
    """
    Converts sequences of Midi Notes to sequences of events under the following 
    representation:
    - 128 NOTE-ON events (for each of the 128 MIDI pitches, starts a new note)
    - 128 NOTE-OFF events (likewise, one per MIDI pitch; each ends a sounding note)
    - (1000 / t) TIME-SHIFT events (each moves the time step forward by increments of
      t ms, up to 1 second)
    - v VELOCITY events (each one changes the velocity applied to all subsequent notes
      until another velocity event occurs)
    Includes functions to cast a sequence of Midi Notes to a numeric list of 
    possible events and to one-hot encode a numeric sequence as a Pytorch tensor.
    """

    def __init__(self, n_time_shift_events, n_velocity_events,
            sequences_per_update=50000, min_events=33, max_events=513):
        self.n_time_shift_events = n_time_shift_events
        self.n_events = 256 + n_time_shift_events + n_velocity_events
        self.timestep = 1 / n_time_shift_events
        self.velocity_bin_size = 128 // n_velocity_events
        self.sequences_per_update = sequences_per_update
        self.min_events = min_events
        self.max_events = max_events

    def encode_sequences(self, sample_sequences):
        """
        Converts each sample note sequence into an "event" sequence, a list of integers
        0 through N-1 where N is the total number of events in the encoder's
        representation.
        """
        event_sequences = []
        #count how many sequences are discarded/truncated due to length
        short_count, long_count = 0,0
        n_sequences = len(sample_sequences)
        for i in range(n_sequences):
            if not (i % self.sequences_per_update):
                print("{:,} / {:,} sequences encoded".\
                        format(i, n_sequences))
            event_sequence = []
            event_timestamps = []
            #attempt at efficiency gain: only add a velocity event if it's different
            #from current velocity...this is tricky if two notes played at the
            #same time have different velocity
            #current_velocity = 0
            for note in sample_sequences[i]:
                #extract start/end time, pitch and velocity
                t0, t1, p, v = note
                event_timestamps.append((t0, "VELOCITY", v))
                #if v != current_velocity:
                #    event_timestamps.append((t0, "VELOCITY", v))
                #    current_velocity = v
                event_timestamps.append((t0, "NOTE_ON", p))
                event_timestamps.append((t1, "NOTE_OFF", p))

            # sort events by timestamp
            event_timestamps = sorted(event_timestamps, key = lambda x: x[0])
            current_time = 0
            max_timeshift = self.n_time_shift_events
            #this loop encodes timeshifts as numbers
            #consider turning this into a function to help readability
            for timestamp in event_timestamps:
                #capture a shift in absolute time
                if timestamp[0] != current_time:
                    #convert to relative time and convert to number of quantized timesteps
                    timeshift = (timestamp[0] *  self.n_time_shift_events) - \
                            (current_time * self.n_time_shift_events)
                    #this is hacky but sue me
                    timeshift = int(timeshift + .1)
                    timeshift_events = []
                    #aggregate pauses longer than one second, as necessary
                    while timeshift > max_timeshift:
                        timeshift_events.append(
                                self.event_to_number("TIME_SHIFT", max_timeshift))
                        timeshift -= max_timeshift
                    #add timeshift (mod 1 second) as an event
                    timeshift_events.append(
                            self.event_to_number("TIME_SHIFT", timeshift))
                    event_sequence.extend(timeshift_events)
                    
                    #add the other events: NOTE_ON, NOTE_OFF, VELOCITY
                    current_time = timestamp[0]
                event_sequence.append(
                        self.event_to_number(timestamp[1], timestamp[2]))

            #check if sequence is too short to keep
            if self.min_events is not None:
                if len(event_sequence) < self.min_events:
                    short_count += 1
                    continue
            #truncate sequence if necessary
            if self.max_events is not None:
                if len(event_sequence) > self.max_events:
                    event_sequence = event_sequence[:self.max_events]
                    long_count += 1

            event_sequences.append(event_sequence)

        if short_count > 0:
            print(f"{short_count} sequences discarded due to brevity")
        if long_count > 0:
            print(f"{long_count} sequences truncated due to excessive length.")

        return event_sequences

                
    def event_to_number(self, event, value):
        """
        Encode an event/value pair as a number 0-N-1
        where N is the number of unique events in the Encoder's representation.
        """
        if event == "NOTE_ON":
            return value
        elif event == "NOTE_OFF":
            return value + 128
        elif event == "TIME_SHIFT":
            #subtract one to fit to zero-index convention
            #i.e. the number 256 corresponds to the smallest possible timestep
            #which is non-zero...!
            return value + 256 - 1
        elif event == "VELOCITY":
            #convert to bins
            v_bin = (value - 1) // self.velocity_bin_size
            return v_bin + 256 + self.n_time_shift_events
        else:
            raise SequenceEncoderError("Event type {} not recognized".format(event))

    def number_to_event(self, number):
        number = int(number)
        if number < 0 or number >= self.n_events:
            raise SequenceEncoderError("Number {} out of range")

        if number < 128:
            event = "NOTE_ON", number
        elif 128 <= number < 256:
            event = "NOTE_OFF", number - 128
        elif 256 <= number < 256 + self.n_time_shift_events:
            event = "TIME_SHIFT", number + 1 - 256
        else:
            bin_number = number - 256 - self.n_time_shift_events
            event = "VELOCITY", (bin_number * self.velocity_bin_size) + 1
        return event

    def decode_sequences(self, encoded_sequences):
        """
        Given a list of encoded sequences, decode each of them and return a list of pretty_midi Note sequences.
        """
        note_sequences = []
        for encoded_sequence in encoded_sequences:
            note_sequences.append(self.decode_sequence(encoded_sequence))

        return note_sequences

    def decode_sequence(self, encoded_sequence, stuck_note_duration=None, keep_ghosts=False, verbose=False):
        """
        Takes in an encoded event sequence (sparse numerical representation) and transforms it back into a pretty_midi Note sequence. Randomly-generated encoded sequences, such as produced by the generation script, can have some unusual traits such as notes without a provided end time. Contains logic to handle these pathological notes.
        Args:
            encoded_sequence (list): List of events encoded as integers
            stuck_note_duration (int or None): if defined, for recovered notes missing an endtime, give them a fixed duration (as number of seconds held)
            keep_ghosts (bool): if true, when the decoding algorithm recovers notes with an end time preceding their start time, keep them by swapping start and end. If false, discard the "ghost" notes
            verbose (bool): If true, print results on how many stuck notes and ghost notes are detected.
        """
        events = []
        for num in encoded_sequence:
            events.append(self.number_to_event(num))
        #list of pseudonotes = {'start':x, 'pitch':something, 'velocity':something}
        notes = []
        #on the second pass, add in end time
        note_ons = []
        note_offs = []
        global_time = 0
        current_velocity = 0
        for event, value in events:
            #check event type
            if event == "TIME_SHIFT":
                global_time += 0.008 * value
                global_time = round(global_time, 5)

            elif event == "VELOCITY":
                current_velocity = value
            
            elif event == "NOTE_OFF":
                #eventually we'll sort this by timestamp and work thru
                note_offs.append({"pitch": value, "end": global_time})
            
            elif event == "NOTE_ON":
                #it's a NOTE_ON!
                #value is pitch 
                note_ons.append({"start": global_time, 
                    "pitch": value, "velocity": current_velocity})
            else:
                raise SequenceEncoderError("you fool!")

        #keep a count of notes that are missing an end time (stuck notes)
        #----default behavior is to ignore them. 
        stuck_notes = 0
        
        #keep a count of notes assigned end times *before* their start times (ghost notes)
        #----default behavior is to ignore them
        ghost_notes = 0


        #Zip up notes with corresponding note-off events
        while len(note_ons) > 0:
            note_on = note_ons[0]
            pitch = note_on['pitch']
            #this assumes everything is sorted nicely!
            note_off = next((n for n in note_offs if n['pitch'] == pitch), None)
            if note_off == None:
                stuck_notes += 1
                if stuck_note_duration is None:
                    note_ons.remove(note_on)
                    continue
                else:
                    note_off = {"pitch": pitch, "end": note_on['start'] + stuck_note_duration}
            else:
                note_offs.remove(note_off)

            if note_off['end'] < note_on['start']:
                ghost_notes += 1
                if keep_ghosts:
                    #reverse start and end (and see what happens...!)
                    new_end = note_on['start']
                    new_start = note_off['end']
                    note_on['start'] = new_start
                    note_off['end'] = new_end
                else:
                    note_ons.remove(note_on)
                    continue

            note = Note(start = note_on['start'], end = note_off['end'],
                    pitch = pitch, velocity = note_on['velocity'])
            notes.append(note)
            note_ons.remove(note_on)

        if verbose:
            print(f"{stuck_notes} notes missing an end-time...")
            print(f"{ghost_notes} had an end-time precede their start-time")

        return notes


In [None]:
# preprocessing pipeline

class PreprocessingError(Exception):
    """Raised when the preprocessing pipeline meets data it cannot process."""

class PreprocessingPipeline():
    #set a random seed
    SEED = 1811
    """
    Pipeline to convert MIDI files to cleaned Piano Midi Note Sequences, split into 
    a more manageable length.
    Applies any sustain pedal activity to extend note lengths. Optionally augments
    the data by transposing pitch and/or stretching sample speed. Optionally quantizes
    timing and/or dynamics into smaller bins.
    Attributes:
        self.split_samples (dict of lists): when the pipeline is run, has two keys, "training" and "validation," each holding a list of split MIDI note sequences.
        self.encoded_sequences (dict of lists): Keys are "training" and "validation." Each holds a list of encoded event sequences, a sparse numeric representation of a MIDI sample.
    """
    def __init__(self, input_dir, stretch_factors = [0.95, 0.975, 1, 1.025, 1.05],
            split_size = 30, sampling_rate = 125, n_velocity_bins = 32,
            transpositions = range(-3,4), training_val_split = 0.9, 
            max_encoded_length = 512, min_encoded_length = 33):
        """
        Args:
            input_dir (str): path to input directory. All .midi or .mid files in this directory will get processed.
            stretch_factors (list of float): List of constants by which note end times and start times will be multiplied. A way to augment data.
            split_size (int): Max length, in seconds, of samples into which longer MIDI note sequences are split.
            sampling_rate (int): How many subdivisions of 1,000 milliseconds to quantize note timings into. E.g. a sampling rate of 100 will mean end and start times are rounded to the nearest 0.01 second.
            n_velocity_bins (int): Quantize 128 Midi velocities (amplitudes) into this many bins: e.g. 32 velocity bins mean note velocities are rounded to the nearest multiple of 4.
            transpositions (iterator of ints): Transpose note pitches up/down by intervals (number of half steps) in this iterator. Augments a dataset with transposed copies.
            training_val_split (float): Number between 0 and 1 defining the proportion of raw data going to the training set. The rest goes to validation.
            max_encoded_length (int): Truncate encoded samples containing more
            events than this number.
            min_encoded_length (int): Discard encoded samples containing fewer events than this number.
        """
        self.input_dir = input_dir
        self.split_samples = dict()
        #copy, so that changing one pipeline's factors can never alter the
        #shared default list (or the caller's list)
        self.stretch_factors = list(stretch_factors)
        #size (in seconds) in which to split midi samples
        self.split_size = split_size
        #In hertz (beats per second), quantize sample timings to this discrete frequency
        #So a sampling rate of 125 hz means a smallest time steps of 8 ms
        self.sampling_rate = sampling_rate
        #Quantize sample dynamics (Velocity 1-127) to a smaller number of bins
        #this should be an *integer* dividing 128 cleanly: 2,4,8,16,32,64, or 128. 
        self.n_velocity_bins = n_velocity_bins
        self.transpositions = transpositions

        #Fraction of raw MIDI data that goes to the training set
        #the remainder goes to validation
        self.training_val_split = training_val_split

        self.encoder = SequenceEncoder(n_time_shift_events = sampling_rate,
                n_velocity_events = n_velocity_bins, 
                min_events = min_encoded_length,
                max_events = max_encoded_length)
        self.encoded_sequences = dict()

        #seed the random module so the training/validation shuffle is repeatable
        random.seed(PreprocessingPipeline.SEED)


    def run(self):
        """
        Run the whole pipeline: parse the midis under input_dir, extract and
        vectorize their note sequences, partition them into training and
        validation sets, then split, quantize and encode each set as event
        sequences. Only the training set is augmented (tempo stretching
        before the split, transposition after quantization).

        Results are stored in self.note_sequences, self.split_samples and
        self.encoded_sequences, each keyed by "training" / "validation".
        """
        midis = self.parse_files(chdir=True)
        total_time = sum(m.get_end_time() for m in midis)
        print("\n{} midis read, or {:.1f} minutes of music"\
                .format(len(midis), total_time/60))

        raw_sequences = self.get_note_sequences(midis)
        #the parsed midis are no longer needed once the notes are extracted
        del midis
        #vectorize note sequences
        note_sequences = [vectorize(ns) for ns in raw_sequences]
        print("{} note sequences extracted\n".format(len(note_sequences)))

        self.note_sequences = self.partition(note_sequences)
        for mode, sequences in self.note_sequences.items():
            augment = (mode == "training")
            print(f"Processing {mode} data...")
            print(f"{len(sequences):,} note sequences")
            if augment:
                sequences = self.stretch_note_sequences(sequences)
                print(f"{len(sequences):,} stretched note sequences")

            samples = self.split_sequences(sequences)
            #quantize works in place on the split samples
            self.quantize(samples)
            print(f"{len(samples):,} quantized, split samples")
            if augment:
                samples = self.transpose_samples(samples)
                print(f"{len(samples):,} transposed samples")

            self.split_samples[mode] = samples
            self.encoded_sequences[mode] = self.encoder.encode_sequences(samples)
            print(f"Encoded {mode} sequences!\n")

    def parse_files(self, chdir=False):
        """
        Recursively parse all MIDI files in a given directory to 
        PrettyMidi objects.
        """
        if chdir: 
            home_dir = os.getcwd()
            os.chdir(self.input_dir)

        pretty_midis = []
        folders = [d for d in os.listdir(os.getcwd()) if os.path.isdir(d)]
        if len(folders) > 0:
            for d in folders:
                os.chdir(d)
                pretty_midis += self.parse_files()
                os.chdir("..")
        midis = [f for f in os.listdir(os.getcwd()) if \
                (f.endswith(".mid") or f.endswith("midi"))]
        print(f"Parsing {len(midis)} midi files in {os.getcwd()}...")
        for m in midis:
            with open(m, "rb") as f:
                try:
                    midi_str = six.BytesIO(f.read())
                    pretty_midis.append(pretty_midi.PrettyMIDI(midi_str))
                    #print("Successfully parsed {}".format(m))
                except:
                    print("Could not parse {}".format(m))
        if chdir:
            os.chdir(home_dir)

        return pretty_midis

    def get_note_sequences(self, midis):
        """
        Given a list of PrettyMidi objects, extract the Piano track as a list of 
        Note objects. Calls the "apply_sustain" method to extract the sustain pedal
        control changes.
        """

        note_sequences = []
        for m in midis:
            if m.instruments[0].program == 0:
                piano_data = m.instruments[0]
            else:
                #todo: write logic to safely catch if there are non piano instruments,
                #or extract the piano midi if it exists
                raise PreprocessingError("Non-piano midi detected")
            note_sequence = self.apply_sustain(piano_data)
            note_sequence = sorted(note_sequence, key = lambda x: (x.start, x.pitch))
            note_sequences.append(note_sequence)

        return note_sequences



    def apply_sustain(self, piano_data):
        """
        While the sustain pedal is applied during a midi, extend the length of all 
        notes to the beginning of the next note of the same pitch or to 
        the end of the sustain. Returns a midi notes sequence.
        """
        _SUSTAIN_ON = 0
        _SUSTAIN_OFF = 1
        _NOTE_ON = 2
        _NOTE_OFF = 3
 
        notes = copy.deepcopy(piano_data.notes)
        control_changes = piano_data.control_changes
        #sequence of SUSTAIN_ON, SUSTAIN_OFF, NOTE_ON, and NOTE_OFF actions
        first_sustain_control = next((c for c in control_changes if c.number == 64),
                ControlChange(number=64, value=0, time=0))

        if first_sustain_control.value >= 64:
            sustain_position = _SUSTAIN_ON
        else:
            sustain_position = _SUSTAIN_OFF
        #if for some reason pedal was not touched...
        action_sequence = [(first_sustain_control.time, sustain_position, None)]
        #delete this please
        cleaned_controls = []
        for c in control_changes:
            #Ignoring the sostenuto and damper pedals due to complications
            if sustain_position == _SUSTAIN_ON:
                if c.value >= 64:
                    #another SUSTAIN_ON
                    continue
                else:
                    sustain_position = _SUSTAIN_OFF
            else:
                #look for the next on signal
                if c.value < 64:
                    #another SUSTAIN_OFF
                    continue
                else:
                    sustain_position = _SUSTAIN_ON
            action_sequence.append((c.time, sustain_position, None))
            cleaned_controls.append((c.time, sustain_position))
    
        action_sequence.extend([(note.start, _NOTE_ON, note) for note in notes])
        action_sequence.extend([(note.end, _NOTE_OFF, note) for note in notes])
        #sort actions by time and type
    
        action_sequence = sorted(action_sequence, key = lambda x: (x[0], x[1]))
        live_notes = []
        sustain = False
        for action in action_sequence:
            if action[1] == _SUSTAIN_ON:
                sustain = True
            elif action[1] == _SUSTAIN_OFF:
                #find when the sustain pedal is released
                off_time = action[0]
                for note in live_notes:
                    if note.end < off_time:
                        #shift the end of the note to when the pedal is released
                        note.end = off_time
                        live_notes.remove(note)
                sustain = False
            elif action[1] == _NOTE_ON:
                current_note = action[2]
                if sustain:
                    for note in live_notes:
                        # if there are live notes of the same pitch being held, kill 'em
                        if current_note.pitch == note.pitch:
                            note.end = current_note.start
                            live_notes.remove(note)
                live_notes.append(current_note)
            else:
                if sustain == True:
                    continue
                else:
                    note = action[2]
                    try:
                        live_notes.remove(note)
                    except ValueError:
                        print("***Unexpected note sequence...possible duplicate?")
                        pass
        return notes

    def partition(self, sequences):
       """
       Partition a list of Note sequences into a training set and validation set.
       Returns a dictionary {"training": training_data, "validation": validation_data}
       """
       partitioned_sequences = {}
       random.shuffle(sequences)

       n_training = int(len(sequences) * self.training_val_split)
       partitioned_sequences['training'] = sequences[:n_training]
       partitioned_sequences['validation'] = sequences[n_training:]

       return partitioned_sequences

    def stretch_note_sequences(self, note_sequences):
        """
        Stretches tempo (note start and end time) for each sequence in a given list
        by each of the pipeline's stretch factors. Returns a list of Note sequences.
        """
        stretched_note_sequences = []
        for note_sequence in note_sequences:
            for factor in self.stretch_factors:
                if factor == 1:
                    stretched_note_sequences.append(note_sequence)
                    continue
                stretched_sequence = np.copy(note_sequence)
                #stretch note start time
                stretched_sequence[:,0] *= factor
                #stretch note end time
                stretched_sequence[:,1] *= factor
                stretched_note_sequences.append(stretched_sequence)

        return stretched_note_sequences


    def split_sequences(self, sequences):
        """
        Given a list of Note sequences, splits them into samples no longer than 
        a given length. Returns a list of split samples.
        """

        samples = []
        if len(sequences) == 0:
            raise PreprocessingError("No note sequences available to split")

        for note_sequence in sequences:
            sample_length = 0
            sample = []
            i = 0
            while i < len(note_sequence):
                note = np.copy(note_sequence[i])
                if sample_length == 0:
                    sample_start = note[0]
                    if note[1] > self.split_size + sample_start:
                        #prevent case of a zero-length sample
                        #print(f"***Current note has length of more than {self.split_size} seconds...reducing duration")
                        note[1] = sample_start + self.split_size
                    sample.append(note)
                    sample_length = self.split_size
                else:
                    if note[1] <= sample_start + self.split_size:
                        sample.append(note)
                        if note[1] > sample_start + sample_length:
                            sample_length = note[1] - sample_start
                    else:
                        samples.append(np.asarray(sample))
                        #sample start should begin with the beginning of the
                        #*next* note, how do I handle this...
                        sample_length = 0
                        sample = []
                i += 1
        return samples

    def quantize(self, samples):
        """
        Quantize timing and dynamics in a Note sample in place. This converts continuous
        time to a discrete, encodable quantity and simplifies input for the model.
        Quantizes note start/ends to a smallest perceptible timestep (~8ms) and note
        velocities to a few audibly distinct bins (around 32).
        """
        #define smallest timestep (in seconds)
        try:
            timestep = 1 / self.sampling_rate
        except ZeroDivisionError:
            timestep = 0
        #define smallest dynamics increment
        try:
            velocity_step = 128 // self.n_velocity_bins
        except ZeroDivisionError:
            velocity_step = 0
        for sample in samples:
            sample_start_time = next((note[0] for note in sample), 0)
            for note in sample:
                #reshift note start and end times to begin at zero
                note[0] -= sample_start_time
                note[1] -= sample_start_time
                #delete this 
                if note[0] < 0 or note[1] < 0:
                    raise PreprocessingError
                if timestep:
                    #quantize timing
                    note[0] = (note[0] * self.sampling_rate) // 1 * timestep
                    note[1] = (note[1] * self.sampling_rate) // 1 * timestep
                if velocity_step:
                    #quantize dynamics
                    #smallest velocity is 1 (otherwise we can't hear it!)
                    note[3] = (note[3] // velocity_step *\
                            velocity_step) + 1

    def transpose_samples(self, samples):
        """
        Transposes the pitch of a sample note by note according to a list of intervals.
        """
        transposed_samples = []
        for sample in samples:
            for transposition in self.transpositions:
                if transposition == 0:
                    transposed_samples.append(sample)
                    continue
                transposed_sample = np.copy(sample)
                #shift pitches in sample by transposition
                transposed_sample[:,2] += transposition
                #should I adjust pitches that fall out of the range of 
                #a piano's 88 keys? going to be pretty uncommon.
                transposed_samples.append(transposed_sample)

        return transposed_samples

In [None]:
# attention code

class AttentionError(Exception):
    """Raised when attention is configured or called with incompatible sizes."""


class MultiheadedAttention(nn.Module):
    """
    Narrow multiheaded self-attention with optional relative positions.

    Each head inspects a ``d_model // heads`` slice of the embedding and
    expresses every sequence position as a weighted average of itself and
    all earlier positions (a causal mask is always applied).

    Args:
        d_model: embedding dimension; must be divisible by ``heads``
        heads: number of attention heads
        dropout: dropout probability applied to the attention weights
        relative_pos: (bool) if True, add relative position scores computed
            from a learned embedding table ``Er``

    Raises:
        AttentionError: if ``heads`` does not divide ``d_model``
    """

    def __init__(self, d_model, heads=8, dropout=0.1, relative_pos=True):

        super().__init__()
        if d_model % heads != 0:
            raise AttentionError("Number of heads does not divide model dimension")
        self.d_model = d_model
        self.heads = heads
        s = d_model // heads
        #one (s x s) projection each for queries, keys and values,
        #shared by all heads
        self.linears = torch.nn.ModuleList([nn.Linear(s, s, bias=False) for i in range(3)])
        self.recombine_heads = nn.Linear(heads * s, d_model)
        self.dropout = nn.Dropout(p=dropout)
        #largest sequence length the relative embeddings can cover
        self.max_length = 1024
        #relative positional embeddings
        self.relative_pos = relative_pos
        if relative_pos:
            #Registered as a Parameter so the embeddings are trained, are
            #stored in state_dict() and follow the module across devices.
            #NOTE: checkpoints written while Er was a plain tensor have no
            #"Er" entry; load those with strict=False.
            self.Er = nn.Parameter(
                    torch.randn([heads, self.max_length, s]))
        else:
            self.Er = None

    def forward(self, x, mask):
        """
        Args:
            x: tensor of size (batch, sequence length, embedding dimension)
            mask: None, or a padding mask in which 0 marks padded
                positions -- assumed to be of size (batch, sequence
                length); TODO confirm against callers
        Returns:
            tensor of the same size as ``x``
        Raises:
            AttentionError: if relative positions are enabled and the
                sequence is longer than ``self.max_length``
        """
        #batch size, sequence length, embedding dimension
        b, t, e = x.size()
        h = self.heads
        #each head inspects a fraction of the embedded space
        #head dimension
        s = e // h
        if self.relative_pos and t > self.max_length:
            raise AttentionError(
                    f"Sequence length {t} exceeds the {self.max_length} "
                    "relative positions available")
        #build helper tensors next to the input rather than on a global device
        device = x.device
        split = x.view(b, t, h, s)
        queries, keys, values = [proj(split).transpose(1, 2)
                for proj in self.linears]
        if self.relative_pos:
            #start index of position embedding
            embedding_start = self.max_length - t
            #apply same position embeddings across the batch
            Er = self.Er[:, embedding_start:, :].unsqueeze(0)
            QEr = torch.matmul(queries, Er.transpose(-1, -2))
            QEr = self._mask_positions(QEr)
            #Get relative position attention scores
            #combine batch with head dimension
            SRel = self._skew(QEr).contiguous().view(b*h, t, t)
        else:
            SRel = torch.zeros([b*h, t, t], device=device,
                    dtype=queries.dtype)
        queries, keys, values = [part.contiguous().view(b*h, t, s)
                for part in (queries, keys, values)]
        #Compute scaled dot-product self-attention
        #scale pre-matrix multiplication
        scale = e ** (1/4)
        queries = queries / scale
        keys = keys / scale

        #(b*h, t, t)
        scores = torch.bmm(queries, keys.transpose(1, 2)) + SRel

        #causal mask: a query may not attend to later keys
        subsequent_mask = torch.triu(torch.ones(1, t, t, device=device),
                1)
        scores = scores.masked_fill(subsequent_mask == 1, -1e9)
        if mask is not None:
            #one mask row per (batch, head) pair
            mask = mask.repeat_interleave(h, 0)
            #fill the whole score row of every padded query position
            padded_queries = (mask == 0).unsqueeze(-1).to(scores.device)
            scores = scores.masked_fill(padded_queries, -1e9)

        #Convert scores to probabilities
        attn_probs = F.softmax(scores, dim=2)
        attn_probs = self.dropout(attn_probs)
        #use attention to get a weighted average of values
        out = torch.bmm(attn_probs, values).view(b, h, t, s)
        #transpose and recombine attention heads
        out = out.transpose(1, 2).contiguous().view(b, t, s * h)
        #last linear layer of weights
        return self.recombine_heads(out)

    def _mask_positions(self, qe):
        #qe is a matrix of queries (absolute position) dot distance
        #embeddings (relative pos). Zero out invalid relative positions:
        #the upper triangle above the diagonal, flipped left-to-right,
        #marks the entries that are cleared.
        L = qe.shape[-1]
        mask = torch.triu(torch.ones(L, L, device=qe.device), 1).flip(1)
        return qe.masked_fill((mask == 1), 0)

    def _skew(self, qe):
        #pad a column of zeros on the left
        padded_qe = F.pad(qe, [1, 0])
        s = padded_qe.shape
        #reinterpret the padded rows with the last two sizes swapped
        padded_qe = padded_qe.view(s[0], s[1], s[3], s[2])
        #take out first (padded) row
        return padded_qe[:, :, 1:, :]

In [None]:
# model code
class MusicTransformerError(Exception):
    """Raised when a MusicTransformer is set up with invalid arguments."""

class MusicTransformer(nn.Module):
    """Generative, autoregressive transformer model. Train on a 
    dataset of encoded musical sequences."""

    def __init__(self, n_tokens, seq_length=None, d_model=64,
            n_heads=4, depth=2, d_feedforward=512, dropout=0.1,
            positional_encoding=False, relative_pos=True):
        """
        Args:
            n_tokens: number of commands/states in encoded musical sequence
            seq_length: length of (padded) input/target sequences; only
                required when positional_encoding is False
            d_model: dimensionality of embedded sequences
            n_heads: number of attention heads
            depth: number of stacked transformer layers
            d_feedforward: dimensionality of dense sublayer 
            dropout: probability of dropout in dropout sublayer
            positional_encoding: (bool) if True, add fixed sinusoidal
                position encodings; if False, add a learned embedding of
                seq_length positions
            relative_pos: (bool) if True, use relative positional embeddings
        Raises:
            MusicTransformerError: if positional_encoding is False and
                seq_length is None
        """
        super().__init__()
        #number of commands in an encoded musical sequence
        self.n_tokens = n_tokens
        #embedding layer
        self.d_model = d_model
        self.embed = SequenceEmbedding(n_tokens, d_model)
        #positional encoding layer
        self.positional_encoding = positional_encoding
        if self.positional_encoding:
            pos = torch.zeros(5000, d_model)
            position = torch.arange(5000).unsqueeze(1)
            #geometric progression of wave lengths
            div_term = torch.exp(torch.arange(0.0, d_model, 2) * \
                            - (math.log(10000.0) / d_model))
            #even positions
            pos[0:, 0::2] = torch.sin(position * div_term)
            #odd positions
            pos[0:, 1::2] = torch.cos(position * div_term)
            #batch dimension
            pos = pos.unsqueeze(0)
            #move to GPU if needed
            pos = pos.to(d())
            #a buffer is saved in state_dict() but is not a trained parameter
            self.register_buffer('pos', pos)
        else:
            if seq_length is None:
                raise MusicTransformerError("seq_length not provided for positional embeddings")
            self.pos = nn.Embedding(seq_length, d_model)
        #last layer, outputs logits of next token in sequence
        self.to_scores = nn.Linear(d_model, n_tokens)
        self.layers = clones(DecoderLayer(d_model, n_heads,
            d_feedforward, dropout, relative_pos), depth)
        self.norm = nn.LayerNorm(d_model)
    
    def forward(self, x, mask=None):
        """
        Args:
            x: batch of token indices, embedded by self.embed into a
                (batch, sequence length, d_model) tensor
            mask: optional padding mask handed unchanged to every layer
        Returns:
            scores over the n_tokens vocabulary for every position
        """
        x = self.embed(x)
        b,t,e = x.size()
        if self.positional_encoding:
            positions = self.pos[:, :t, :]
        else:
            #index the learned positions on the same device as the input
            positions = self.pos(torch.arange(t, 
                device=x.device))[None, :, :].expand(b, t, e)
        x = x + positions
        #another dropout layer here?
        #pass input batch and mask through layers
        for layer in self.layers:
            x  = layer(x, mask)
        #one last normalization for good measure
        z = self.norm(x)
        return self.to_scores(z)

class DecoderLayer(nn.Module):
    """
    One decoder block: masked self-attention followed by a position-wise
    feed-forward network. Each sublayer's output goes through dropout, is
    added back to the sublayer's input, and the sum is layer-normalized.
    """

    def __init__(self, size, n_heads, d_feedforward, dropout,
            relative_pos):
        super().__init__()
        self.size = size
        self.self_attn = MultiheadedAttention(size, n_heads,
                dropout, relative_pos)
        self.feed_forward = PositionwiseFeedForward(size, d_feedforward, dropout)
        #normalize over mean/std of embedding dimension
        self.norm1 = nn.LayerNorm(size)
        self.norm2 = nn.LayerNorm(size)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask):
        #attention sublayer: masked so queries cannot attend to
        #subsequent keys
        attended = self.norm1(x + self.dropout1(self.self_attn(x, mask)))
        #feed-forward sublayer with the same residual + norm pattern
        transformed = self.dropout2(self.feed_forward(attended))
        return self.norm2(attended + transformed)

class PositionwiseFeedForward(nn.Module):
    """
    Two linear layers (d_model -> d_ff -> d_model) with a ReLU and dropout
    in between.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        #expand, apply the non-linearity, regularize, then project back
        hidden = F.relu(self.w_1(x))
        hidden = self.dropout(hidden)
        return self.w_2(hidden)

class SequenceEmbedding(nn.Module):
    """
    Token embedding whose output is multiplied by sqrt(model_size).
    """
    def __init__(self, vocab_size, model_size):
        super().__init__()
        self.d_model = model_size
        self.emb = nn.Embedding(vocab_size, model_size)

    def forward(self, x):
        embedded = self.emb(x)
        scale = math.sqrt(self.d_model)
        return embedded * scale

In [None]:
# train code
def batch_to_tensors(batch, n_tokens, max_length):
    """
    Make input, input mask, and target tensors for a batch of sequences.

    Args:
        batch: pair (input_sequences, target_sequences)
        n_tokens: not used by this function
        max_length: width every sequence is zero-padded to
    Returns:
        (x, y, x_mask): long tensors x and y of size
        (batch size, max_length), and a uint8 mask that is 1 wherever x is
        non-zero. All three are moved to the GPU when one is available.
    """
    inputs, targets = batch
    #targets are padded using the lengths of the matching input sequences
    lengths = [len(seq) for seq in inputs]
    n_rows = len(inputs)

    def pad(sequences):
        #zero is the padding element
        padded = torch.zeros(n_rows, max_length, dtype=torch.long)
        for row, seq in enumerate(sequences):
            #values are cast to long on assignment
            padded[row, :lengths[row]] = torch.Tensor(seq).unsqueeze(0)
        return padded

    x = pad(inputs)
    x_mask = (x != 0).type(torch.uint8)
    y = pad(targets)

    if not torch.cuda.is_available():
        return x, y, x_mask
    return x.cuda(), y.cuda(), x_mask.cuda()

def train(model, training_data, validation_data,
        epochs, batch_size, batches_per_print=100, evaluate_per=1,
        padding_index=-100, checkpoint_path=None,
        custom_schedule=False, custom_loss=False):
    """
    Training loop function.
    Args:
        model: MusicTransformer module
        training_data: List of encoded music sequences (shuffled in place
            after every epoch)
        validation_data: List of encoded music sequences (shuffled in place
            after every evaluation)
        epochs: Number of iterations over training batches
        batch_size: Number of sequences per batch; batches of any other
            size are skipped
        batches_per_print: How often to print training loss
        evaluate_per: calculate validation loss after this many epochs
        padding_index: ignore this sequence token in loss calculation
        checkpoint_path: (str or None) If defined, save the model's state dict to this file path after validation
        custom_schedule: (bool) If True, use a learning rate scheduler with a warmup ramp
        custom_loss: (bool) If True, set loss function as Cross Entropy with label smoothing
    Returns:
        List with the training loss of every processed batch.
    """

    #Move the model first: the torch.optim documentation asks for this to
    #happen before an optimizer is constructed for the model.
    if torch.cuda.is_available():
        model.cuda()
        print("GPU is available")
    else:
        print("GPU not available, CPU used")

    model.train()
    optimizer = torch.optim.Adam(model.parameters())

    if custom_schedule:
        optimizer = TFSchedule(optimizer, model.d_model)
    
    if custom_loss:
        loss_function = smooth_cross_entropy
    else:
        loss_function = nn.CrossEntropyLoss(ignore_index=padding_index)
    accuracy = Accuracy()

    training_losses = []
    #collected for inspection only; the return value stays the training
    #losses for backward compatibility
    validation_losses = []
    #pad to length of longest sequence
    #minus one because input/target sequences are shifted by one char
    max_length = max((len(sequence)
        for sequence in (training_data + validation_data))) - 1
    for e in range(epochs):
        batch_start_time = time.time()
        batch_num = 1
        averaged_loss = 0
        averaged_accuracy = 0
        training_batches = prepare_batches(training_data, batch_size) #returning batches of a given size
        for batch in training_batches:

            #skip batches that are undersized
            if len(batch[0]) != batch_size:
                continue
            x, y, x_mask = batch_to_tensors(batch, model.n_tokens, 
                    max_length)
            y_hat = model(x, x_mask).transpose(1,2)

            #shape: (batch_size, n_tokens, seq_length)

            loss = loss_function(y_hat, y)

            #clear old gradients from previous step
            model.zero_grad()
            #compute derivative of loss w/r/t parameters
            loss.backward()
            #optimizer takes a step based on gradient
            optimizer.step()
            training_loss = loss.item()
            training_losses.append(training_loss)
            #running sums, reset after every printout
            averaged_loss += training_loss
            averaged_accuracy += accuracy(y_hat, y, x_mask)
            if batch_num % batches_per_print == 0:
                print(f"batch {batch_num}, loss: {averaged_loss / batches_per_print : .2f}")
                print(f"accuracy: {averaged_accuracy / batches_per_print : .2f}")
                averaged_loss = 0
                averaged_accuracy = 0
            batch_num += 1

        print(f"epoch: {e+1}/{epochs} | time: {(time.time() - batch_start_time) / 60:,.0f}m")
        shuffle(training_data)

        if (e + 1) % evaluate_per == 0:

            #switch layers such as dropout to evaluation behaviour
            model.eval()
            validation_batches = prepare_batches(validation_data,
                    batch_size)
            #get loss per batch
            val_loss = 0
            n_batches = 0
            val_accuracy = 0
            #no gradients are needed for evaluation, so do not record them
            with torch.no_grad():
                for batch in validation_batches:

                    if len(batch[0]) != batch_size:
                        continue

                    x, y, x_mask = batch_to_tensors(batch, model.n_tokens, 
                            max_length)

                    y_hat = model(x, x_mask).transpose(1,2)
                    loss = loss_function(y_hat, y)
                    val_loss += loss.item()
                    val_accuracy += accuracy(y_hat, y, x_mask)
                    n_batches += 1

            if checkpoint_path is not None:
                #saving is best-effort: report the failure and keep training,
                #but let KeyboardInterrupt/SystemExit propagate
                try:
                    torch.save(model.state_dict(),
                            checkpoint_path+f"_e{e}")
                    print("Checkpoint saved!")
                except Exception as err:
                    print(f"Error: checkpoint could not be saved... ({err})")

            model.train()
            if n_batches:
                #average out validation loss
                val_accuracy = (val_accuracy / n_batches)
                val_loss = (val_loss / n_batches)
                validation_losses.append(val_loss)
                print(f"validation loss: {val_loss:.2f}")
                print(f"validation accuracy: {val_accuracy:.2f}")
            else:
                #every validation batch was undersized; nothing to average
                print("validation skipped: no full-size validation batch")
            shuffle(validation_data)

    return training_losses

In [None]:
# run code
def main():
    """
    Build a MusicTransformer, preprocess the MIDI files under "data", and
    train the model, saving a checkpoint under "saved_models/" after each
    evaluation.
    """
    # parser = argparse.ArgumentParser("Script to train model on a GPU")
    # parser.add_argument("--checkpoint", type=str, default=None,
    #         help="Optional path to saved model, if none provided, the model is trained from scratch.")
    # parser.add_argument("--n_epochs", type=int, default=5,
    #         help="Number of training epochs.")
    # args = parser.parse_args()
    args = {}
    args['checkpoint'] = None
    #key must match the lookup in the train() call below
    args['n_epochs'] = 5
    
    sampling_rate = 125
    n_velocity_bins = 32
    seq_length = 1024
    n_tokens = 256 + sampling_rate + n_velocity_bins
    transformer = MusicTransformer(n_tokens, seq_length, 
            d_model = 64, n_heads = 8, d_feedforward=256, 
            depth = 4, positional_encoding=True, relative_pos=True)

    if args['checkpoint'] is not None:
        #load onto the CPU so a checkpoint saved on a GPU machine can be
        #restored anywhere; load_state_dict copies into the model's tensors
        state = torch.load(args['checkpoint'], map_location="cpu")
        transformer.load_state_dict(state)
        print(f"Successfully loaded checkpoint at {args['checkpoint']}")
    #rule of thumb: 1 minute is roughly 2k tokens
    
    pipeline = PreprocessingPipeline(input_dir="data", stretch_factors=[0.975, 1, 1.025],
            split_size=30, sampling_rate=sampling_rate, n_velocity_bins=n_velocity_bins,
            transpositions=range(-2,3), training_val_split=0.9, max_encoded_length=seq_length+1,
                                    min_encoded_length=257)
    pipeline_start = time.time()
    pipeline.run()
    runtime = time.time() - pipeline_start
    print(f"MIDI pipeline runtime: {runtime / 60 : .1f}m")

    today = datetime.date.today().strftime('%m%d%Y')
    #make sure the checkpoint directory exists before training tries to
    #write into it
    os.makedirs("saved_models", exist_ok=True)
    checkpoint = f"saved_models/tf_{today}"

    training_sequences = pipeline.encoded_sequences['training']
    validation_sequences = pipeline.encoded_sequences['validation']
    
    batch_size = 16
    
    train(transformer, training_sequences, validation_sequences,
               epochs = args['n_epochs'], evaluate_per = 1,
               batch_size = batch_size, batches_per_print=100,
               padding_index=0, checkpoint_path=checkpoint)

# **Run Code**

In [None]:
# code to reference 
# https://github.com/chathasphere/pno-ai

In [None]:
# testing the reference code
! git clone https://github.com/chathasphere/pno-ai

Cloning into 'pno-ai'...
remote: Enumerating objects: 94, done.[K
remote: Counting objects: 100% (94/94), done.[K
remote: Compressing objects: 100% (62/62), done.[K
remote: Total 340 (delta 53), reused 66 (delta 31), pack-reused 246[K
Receiving objects: 100% (340/340), 99.72 KiB | 9.06 MiB/s, done.
Resolving deltas: 100% (207/207), done.


In [None]:
! cd pno-ai/ && wget https://storage.googleapis.com/magentadata/datasets/maestro/v2.0.0/maestro-v2.0.0-midi.zip && unzip -q maestro-v2.0.0-midi.zip -d data/

--2021-03-30 21:21:49--  https://storage.googleapis.com/magentadata/datasets/maestro/v2.0.0/maestro-v2.0.0-midi.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 74.125.20.128, 74.125.142.128, 74.125.195.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|74.125.20.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 59243107 (56M) [application/zip]
Saving to: ‘maestro-v2.0.0-midi.zip’


2021-03-30 21:21:50 (70.7 MB/s) - ‘maestro-v2.0.0-midi.zip’ saved [59243107/59243107]



In [None]:
# Remove part of the dataset so that the run fits within Colab's limits
! cd pno-ai/data/maestro-v2.0.0/ && rm -rf 2004/ && rm -rf 2006 && rm -rf 2008 && rm -rf 2009

In [None]:
%cd pno-ai/
! python3 run.py
# main()

[Errno 2] No such file or directory: 'pno-ai/'
/content/pno-ai
Parsing 129 midi files in /content/pno-ai/data/maestro-v2.0.0/2015...
Could not parse MIDI-Unprocessed_R1_D1-9-12_mid--AUDIO-from_mp3_09_R1_2015_wav--2.midi
Traceback (most recent call last):
  File "run.py", line 59, in <module>
    main()
  File "run.py", line 40, in main
    pipeline.run()
  File "/content/pno-ai/preprocess/pipeline.py", line 75, in run
    midis = self.parse_files(chdir=True) 
  File "/content/pno-ai/preprocess/pipeline.py", line 116, in parse_files
    pretty_midis += self.parse_files()
  File "/content/pno-ai/preprocess/pipeline.py", line 116, in parse_files
    pretty_midis += self.parse_files()
  File "/content/pno-ai/preprocess/pipeline.py", line 128, in parse_files
    print("Could not parse {}".format(m))
KeyboardInterrupt


# **SetUp** **Imports**

In [None]:
! git clone https://github.com/czhuang/JSB-Chorales-dataset.git

# Data Loader

In [None]:
# import pickle
# import numpy as np
# with open('JSB-Chorales-dataset/jsb-chorales-16th.pkl', 'rb') as p:
#     data = pickle.load(p, encoding="latin1")
# test_data = data['test']
# train_data = data['train']
# valid_data = data['valid']

# print(valid_data)
# # creating the dataloaders


In [None]:
# Using maestro data
! cd pno-ai/
from preprocess import PreprocessingPipeline

# LSTM Model

In [None]:
class LSTMClassifier(nn.Module):
    """
    A regular one layer LSTM classifier using LSTMCell -> FFNetwork structure

    Args:
        input_dim: size of each input step
        hidden_dim: size of the LSTM hidden and cell states
        label_size: number of output units
        device: stored on the instance as ``self.device`` for backward
            compatibility; forward() places its state on the input's
            device instead
        dropout_rate: dropout probability applied to the last hidden state
    """
    def __init__(self, input_dim, hidden_dim, label_size, device=torch.device("cuda"), dropout_rate=0.1):
        super().__init__()
        self.lstm = nn.LSTMCell(input_dim, hidden_dim)
        #bottleneck of width floor(sqrt(hidden_dim)) before the output layer
        bottleneck = int(np.sqrt(hidden_dim))
        self.hidden2ff = nn.Linear(hidden_dim, bottleneck)
        self.ff2label = nn.Linear(bottleneck, label_size)
        self.hidden_dim = hidden_dim
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(dropout_rate)
        self.device = device
        self.initialize_weights()

    def initialize_weights(self):
        """Set every bias to 0.0001 and draw every weight from Xavier normal."""
        for module in (self.lstm, self.hidden2ff, self.ff2label):
            for name, param in module.named_parameters():
                if 'bias' in name:
                    nn.init.constant_(param, 0.0001)
                elif 'weight' in name:
                    nn.init.xavier_normal_(param)

    def forward(self, x):
        """
        Run the cell over dimension 1 of ``x`` and classify the final
        hidden state.

        Args:
            x: tensor indexed as (batch, step, feature)
        Returns:
            sigmoid outputs of size (batch, label_size)
        """
        #Create the initial state next to the input. Forcing it onto
        #self.device crashed on CPU-only machines, because the default
        #device is CUDA.
        hs = torch.zeros(x.size(0), self.hidden_dim, device=x.device)
        cs = torch.zeros(x.size(0), self.hidden_dim, device=x.device)

        for i in range(x.size()[1]):
            hs, cs = self.lstm(x[:, i], (hs, cs))

        hs = self.dropout(hs)
        hs = self.hidden2ff(hs)
        return self.sigmoid(self.ff2label(hs))

In [None]:
#Longformer code 

Longformer Attention Layer

changing attention 

optimizing initialization 

## Transform Base Model