In [48]:
import glob
import random
from tqdm import tnrange, tqdm_notebook, tqdm
import pretty_midi
import numpy as np
import os
import io
import sys
import traceback
import tensorflow as tf
from tensorflow.keras import backend as K
import IPython
import pickle


# Fixed seed so that the shuffle below (and therefore the selected training
# subset) is the same on every "Restart & Run All".
seed_int = 666
random.seed(seed_int)
np.random.seed(seed_int)

# NOTE: absolute, machine-specific location of the MIDI corpus; adjust it when
# running the notebook elsewhere.
# glob returns files in arbitrary order, so sort first to make the seeded
# shuffle deterministic across machines and runs.
list_all_midi = sorted(glob.glob("C:/Users/VARNA/Desktop/composser/midi/*.midi"))
random.shuffle(list_all_midi)

# Work on (at most) the first 100 shuffled files.
sample_midi = list_all_midi[0:100]

# Preprocessing MIDI files

In [49]:
class NoteTokenizer:
    """Vocabulary of note strings with a two-way mapping to integer ids.

    Ids are handed out starting at 1, in order of first appearance; id 0 is
    never assigned by this class.
    """

    def __init__(self):
        self.notes_to_index = {}   # note string -> integer id
        self.index_to_notes = {}   # integer id  -> note string
        self.num_of_word = 0       # total number of notes seen by partial_fit
        self.unique_word = 0       # number of distinct entries in the vocabulary
        self.notes_freq = {}       # note string -> occurrence count (partial_fit only)

    def transform(self, list_array):
        """ Convert sequences of note strings into their integer ids.

        Parameters
        ==========
        list_array : list
          list of sequences, each a list of note strings already present in
          the vocabulary

        Returns
        =======
        The ids as a numpy array of dtype int32.

        """
        lookup = self.notes_to_index
        encoded = [[lookup[note] for note in sequence] for sequence in list_array]
        return np.array(encoded, dtype=np.int32)

    def partial_fit(self, notes):
        """ Update the vocabulary and the frequency counts with new notes.

        Parameters
        ==========
        notes : list of notes
          each note is an iterable of values; it is keyed by the
          comma-joined string of those values

        """
        for note in notes:
            key = ','.join(map(str, note))
            self.num_of_word += 1
            if key not in self.notes_freq:
                # First occurrence: register the note under the next free id.
                self.notes_freq[key] = 1
                self.unique_word += 1
                self.notes_to_index[key] = self.unique_word
                self.index_to_notes[self.unique_word] = key
            else:
                self.notes_freq[key] += 1

    def add_new_note(self, note):
        """ Register a single note string that is not in the vocabulary yet.

        Parameters
        ==========
        note : str
          a new note that must not already be in the dictionary.

        """
        assert note not in self.notes_to_index
        self.unique_word += 1
        new_index = self.unique_word
        self.notes_to_index[note] = new_index
        self.index_to_notes[new_index] = note

In [50]:
def generate_dict_time_notes(list_all_midi, batch_song=16, start_index=0, fs=30, use_tqdm=True):
    """ Load a batch of MIDI files and map each song index to its piano roll.

    Parameters
    ==========
    list_all_midi : list
        List of midi files
    batch_song : int
      Number of songs in one batch
    start_index : int
      Index in list_all_midi at which the batch starts
    fs : int
      Sampling frequency of the columns, i.e. each column is spaced apart
        by ``1./fs`` seconds.
    use_tqdm : bool
      Whether to show a tqdm progress bar while loading

    Returns
    =======
    Dictionary {index in list_all_midi: piano roll (np.array)}. Files that
    fail to load are reported on stdout and left out of the dictionary.

    """
    assert len(list_all_midi) >= batch_song

    # The batch is clipped at the end of the list, so the last one may be short.
    stop_index = min(start_index + batch_song, len(list_all_midi))
    song_indices = range(start_index, stop_index)
    if use_tqdm:
        song_indices = tqdm_notebook(song_indices)

    dict_time_notes = {}
    for song_index in song_indices:
        midi_file_name = list_all_midi[song_index]
        if use_tqdm:
            song_indices.set_description("Processing {}".format(midi_file_name))
        try:  # Malformed MIDI files are skipped rather than aborting the batch
            # Only the first instrument of the file is used.
            first_instrument = pretty_midi.PrettyMIDI(midi_file_name).instruments[0]
            dict_time_notes[song_index] = first_instrument.get_piano_roll(fs=fs)
        except Exception as e:
            print(e)
            print("broken file : {}".format(midi_file_name))
    return dict_time_notes

def generate_input_and_target(dict_keys_time, seq_len=50):
    """ Generate the network inputs and targets for one song.

    For every timestep ``t`` from the first key up to (but excluding) the last
    key, the input is the window of ``seq_len`` timesteps ending at ``t`` and
    the target is the notes at ``t + 1``. Notes are encoded as the
    comma-joined string of their values; timesteps without notes, and window
    positions before the start of the song, are encoded as ``'e'``.

    Parameters
    ==========
    dict_keys_time : dict
      Dictionary of timestep -> notes played at that timestep. The first and
      the last key (in insertion order) are taken as start and end time.
    seq_len : int
      The length of the sequence

    Returns
    =======
    Tuple (list of inputs, list of targets). Each input is a list of
    ``seq_len`` note strings, each target a one-element list. Both lists are
    empty when the song has no timesteps at all.

    """
    # A song without any active note yields an empty dict; it contributes no
    # training samples instead of failing on the key lookup below.
    if not dict_keys_time:
        return [], []

    time_keys = list(dict_keys_time.keys())
    start_time, end_time = time_keys[0], time_keys[-1]

    def encode(step):
        # Comma-joined notes of the timestep, or 'e' when nothing is played.
        if step in dict_keys_time:
            return ','.join(str(x) for x in dict_keys_time[step])
        return 'e'

    list_training, list_target = [], []
    for time in range(start_time, end_time):
        # Window of seq_len steps ending at `time`; positions that fall before
        # the start of the song are always padded with 'e'.
        window = [
            'e' if step < start_time else encode(step)
            for step in range(time - seq_len + 1, time + 1)
        ]
        list_training.append(window)
        list_target.append([encode(time + 1)])
    return list_training, list_target

def process_notes_in_song(dict_time_notes, seq_len = 50):
    """
    Convert each piano roll into a dictionary of timestep -> notes played.

    Parameters
    ==========
    dict_time_notes : dict
      dict mapping a song index to its piano_roll (2-D np.array indexed as
      [note, timestep])
    seq_len : int
      Length of the sequence (not used by this function)

    Returns
    =======
    List with one dict per song, in the iteration order of dict_time_notes.
    Each dict maps a timestep that has at least one positive entry to the
    array of row (note) indices that are positive at that timestep.
    """
    list_of_dict_keys_time = []

    for piano_roll in dict_time_notes.values():
        # Row (note) and column (timestep) coordinates of every active cell.
        active_notes, active_times = np.where(piano_roll > 0)
        list_of_dict_keys_time.append({
            step: active_notes[active_times == step]
            for step in np.unique(active_times)
        })
    return list_of_dict_keys_time


In [51]:
# Build the note vocabulary from every selected MIDI file.
# NOTE(review): `batch` and `start_index` are not read anywhere in this cell.
batch = 1
start_index = 0
note_tokenizer = NoteTokenizer()
# tqdm_notebook only provides the progress bar; songs are loaded one at a time
# (batch_song=1) so that a single piano roll is held in memory per iteration.
for i in tqdm_notebook(range(len(sample_midi))):
    # fs=5 is the same sampling rate as FRAME_PER_SECOND used for training.
    dict_time_notes = generate_dict_time_notes(sample_midi, batch_song=1, start_index=i, use_tqdm=False, fs=5)
    full_notes = process_notes_in_song(dict_time_notes)
    #print(full_notes)
    for note in full_notes:
        # Each value is the array of notes played at one timestep.
        note_tokenizer.partial_fit(list(note.values()))

HBox(children=(IntProgress(value=0), HTML(value='')))

In [52]:
# 'e' is the token generate_input_and_target emits for padding and for
# timesteps without notes, so it must be part of the vocabulary.
# Re-running this cell fails on the assertion inside add_new_note because 'e'
# is then already registered.
note_tokenizer.add_new_note('e')
# Vocabulary size (including 'e'); used to size the embedding and output layers.
unique_notes = note_tokenizer.unique_word
print(unique_notes)

38634


# Architecture
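The architecture is as follows:
1. Embedding (used only as the first layer, to convert the 2D input of token indices into a 3D tensor of embedding vectors)
2. Recurrent layer (the code below uses a bidirectional GRU rather than an LSTM)
3. Self-attention
4. Recurrent layer (bidirectional GRU)
5. Self-attention
6. Dense

Dropout follows each self-attention layer, and a final bidirectional GRU with dropout sits in front of the dense head.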

In [53]:
# Training hyperparameters.
seq_len = 50                      # number of timesteps in one input sequence
EPOCHS = 4                        # passes over the selected songs
BATCH_SONG = 16                   # songs loaded and preprocessed per outer training step
BATCH_NNET_SIZE = 96              # sequences per gradient step
TOTAL_SONGS = len(sample_midi)    # number of songs iterated in each epoch
FRAME_PER_SECOND = 5              # piano-roll sampling rate (fs) passed to pretty_midi

In [54]:
class SeqSelfAttention(tf.keras.layers.Layer):
    """Self-attention over the time axis of a 3-D sequence tensor.

    For every position an emission score against every other position is
    computed (additive or multiplicative), normalised along the last axis, and
    used to build a weighted sum of the input vectors.
    """

    ATTENTION_TYPE_ADD = 'additive'
    ATTENTION_TYPE_MUL = 'multiplicative'

    def __init__(self,
                 units=32,
                 attention_width=None,
                 attention_type=ATTENTION_TYPE_ADD,
                 return_attention=False,
                 history_only=False,
                 kernel_initializer='glorot_normal',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 use_additive_bias=True,
                 use_attention_bias=True,
                 attention_activation=None,
                 attention_regularizer_weight=0.0,
                 **kwargs):
        """Layer initialization.
        For additive attention, see: https://arxiv.org/pdf/1806.01264.pdf
        :param units: The dimension of the vectors that used to calculate the attention weights.
        :param attention_width: The width of local attention.
        :param attention_type: 'additive' or 'multiplicative'.
        :param return_attention: Whether to return the attention weights for visualization.
        :param history_only: Only use historical pieces of data.
        :param kernel_initializer: The initializer for weight matrices.
        :param bias_initializer: The initializer for biases.
        :param kernel_regularizer: The regularization for weight matrices.
        :param bias_regularizer: The regularization for biases.
        :param kernel_constraint: The constraint for weight matrices.
        :param bias_constraint: The constraint for biases.
        :param use_additive_bias: Whether to use bias while calculating the relevance of inputs features
                                  in additive mode.
        :param use_attention_bias: Whether to use bias while calculating the weights of attention.
        :param attention_activation: The activation used for calculating the weights of attention.
        :param attention_regularizer_weight: The weights of attention regularizer.
        :param kwargs: Parameters for parent class.
        :raises NotImplementedError: If attention_type is neither 'additive' nor 'multiplicative'.
        """
        # Reject unsupported types before any layer state is created.
        if attention_type not in (SeqSelfAttention.ATTENTION_TYPE_ADD,
                                  SeqSelfAttention.ATTENTION_TYPE_MUL):
            raise NotImplementedError('No implementation for attention type : ' + attention_type)

        # Initialise the Keras base layer before assigning attributes so that
        # its attribute tracking is in place when they are set.
        super(SeqSelfAttention, self).__init__(**kwargs)

        self.supports_masking = True
        self.units = units
        self.attention_width = attention_width
        self.attention_type = attention_type
        self.return_attention = return_attention
        self.history_only = history_only
        if history_only and attention_width is None:
            # No explicit width: use a window large enough to cover any sequence.
            self.attention_width = int(1e9)

        self.use_additive_bias = use_additive_bias
        self.use_attention_bias = use_attention_bias
        self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)
        self.bias_initializer = tf.keras.initializers.get(bias_initializer)
        self.kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
        self.bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
        self.kernel_constraint = tf.keras.constraints.get(kernel_constraint)
        self.bias_constraint = tf.keras.constraints.get(bias_constraint)
        self.attention_activation = tf.keras.activations.get(attention_activation)
        self.attention_regularizer_weight = attention_regularizer_weight
        self._backend = tf.keras.backend.backend()

        # Weights are created in build(); which ones exist depends on the type.
        if attention_type == SeqSelfAttention.ATTENTION_TYPE_ADD:
            self.Wx, self.Wt, self.bh = None, None, None
            self.Wa, self.ba = None, None
        else:
            self.Wa, self.ba = None, None

    def get_config(self):
        """Return the constructor arguments needed to re-create this layer."""
        config = {
            'units': self.units,
            'attention_width': self.attention_width,
            'attention_type': self.attention_type,
            'return_attention': self.return_attention,
            'history_only': self.history_only,
            'use_additive_bias': self.use_additive_bias,
            'use_attention_bias': self.use_attention_bias,
            # Initializers are serialised with the initializers module, matching
            # the tf.keras.initializers.get() calls in __init__.
            'kernel_initializer': tf.keras.initializers.serialize(self.kernel_initializer),
            'bias_initializer': tf.keras.initializers.serialize(self.bias_initializer),
            'kernel_regularizer': tf.keras.regularizers.serialize(self.kernel_regularizer),
            'bias_regularizer': tf.keras.regularizers.serialize(self.bias_regularizer),
            'kernel_constraint': tf.keras.constraints.serialize(self.kernel_constraint),
            'bias_constraint': tf.keras.constraints.serialize(self.bias_constraint),
            'attention_activation': tf.keras.activations.serialize(self.attention_activation),
            'attention_regularizer_weight': self.attention_regularizer_weight,
        }
        base_config = super(SeqSelfAttention, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def build(self, input_shape):
        """Create the weights for the configured attention type."""
        if isinstance(input_shape, list):
            # [inputs, positions]: only the shape of `inputs` sizes the weights.
            input_shape = input_shape[0]
        if self.attention_type == SeqSelfAttention.ATTENTION_TYPE_ADD:
            self._build_additive_attention(input_shape)
        elif self.attention_type == SeqSelfAttention.ATTENTION_TYPE_MUL:
            self._build_multiplicative_attention(input_shape)
        super(SeqSelfAttention, self).build(input_shape)

    def _build_additive_attention(self, input_shape):
        """Create Wt, Wx, Wa and the optional biases bh and ba."""
        feature_dim = input_shape[2]

        self.Wt = self.add_weight(shape=(feature_dim, self.units),
                                  name='{}_Add_Wt'.format(self.name),
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
        self.Wx = self.add_weight(shape=(feature_dim, self.units),
                                  name='{}_Add_Wx'.format(self.name),
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
        if self.use_additive_bias:
            self.bh = self.add_weight(shape=(self.units,),
                                      name='{}_Add_bh'.format(self.name),
                                      initializer=self.bias_initializer,
                                      regularizer=self.bias_regularizer,
                                      constraint=self.bias_constraint)

        self.Wa = self.add_weight(shape=(self.units, 1),
                                  name='{}_Add_Wa'.format(self.name),
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
        if self.use_attention_bias:
            self.ba = self.add_weight(shape=(1,),
                                      name='{}_Add_ba'.format(self.name),
                                      initializer=self.bias_initializer,
                                      regularizer=self.bias_regularizer,
                                      constraint=self.bias_constraint)

    def _build_multiplicative_attention(self, input_shape):
        """Create Wa and the optional bias ba."""
        feature_dim = input_shape[2]

        self.Wa = self.add_weight(shape=(feature_dim, feature_dim),
                                  name='{}_Mul_Wa'.format(self.name),
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
        if self.use_attention_bias:
            self.ba = self.add_weight(shape=(1,),
                                      name='{}_Mul_ba'.format(self.name),
                                      initializer=self.bias_initializer,
                                      regularizer=self.bias_regularizer,
                                      constraint=self.bias_constraint)

    def call(self, inputs, mask=None, **kwargs):
        """Apply self-attention.

        :param inputs: A sequence tensor, or a list ``[inputs, positions]``; in
                       the latter case only the rows at ``positions`` are returned.
        :param mask: Optional mask (a list of masks when ``inputs`` is a list).
        :return: The attended tensor, or ``[attended, attention_weights]`` when
                 ``return_attention`` is set.
        """
        if isinstance(inputs, list):
            inputs, positions = inputs
            positions = K.cast(positions, 'int32')
            # No mask is supplied when none of the inputs carries one.
            if mask is not None:
                mask = mask[1]
        else:
            positions = None

        input_len = K.shape(inputs)[1]

        if self.attention_type == SeqSelfAttention.ATTENTION_TYPE_ADD:
            e = self._call_additive_emission(inputs)
        elif self.attention_type == SeqSelfAttention.ATTENTION_TYPE_MUL:
            e = self._call_multiplicative_emission(inputs)

        if self.attention_activation is not None:
            e = self.attention_activation(e)
        # Subtract the row maximum before exponentiating for numerical stability.
        e = K.exp(e - K.max(e, axis=-1, keepdims=True))
        if self.attention_width is not None:
            # Band matrix that zeroes the scores outside the local window.
            ones = tf.ones((input_len, input_len))
            if self.history_only:
                # Lower band only: the current position and earlier ones.
                local = tf.linalg.band_part(
                    ones,
                    K.minimum(input_len, self.attention_width - 1),
                    0,
                )
            else:
                local = tf.linalg.band_part(
                    ones,
                    K.minimum(input_len, self.attention_width // 2),
                    K.minimum(input_len, (self.attention_width - 1) // 2),
                )
            e = e * K.expand_dims(local, 0)
        if mask is not None:
            # Zero the scores of masked positions along both sequence axes.
            mask = K.cast(mask, K.floatx())
            mask = K.expand_dims(mask)
            e = K.permute_dimensions(K.permute_dimensions(e * mask, (0, 2, 1)) * mask, (0, 2, 1))

        # a_{t} = \text{softmax}(e_t)
        s = K.sum(e, axis=-1)
        s = K.tile(K.expand_dims(s, axis=-1), K.stack([1, 1, input_len]))
        a = e / (s + K.epsilon())

        # l_t = \sum_{t'} a_{t, t'} x_{t'}
        v = K.batch_dot(a, inputs)
        if self.attention_regularizer_weight > 0.0:
            self.add_loss(self._attention_regularizer(a))

        if positions is not None:
            # Keep only the requested positions of every batch element.
            pos_num = K.shape(positions)[1]
            batch_indices = K.tile(K.expand_dims(K.arange(K.shape(inputs)[0]), axis=-1), K.stack([1, pos_num]))
            pos_indices = K.stack([batch_indices, positions], axis=-1)
            v = tf.gather_nd(v, pos_indices)
            a = tf.gather_nd(a, pos_indices)

        if self.return_attention:
            return [v, a]
        return v

    def _call_additive_emission(self, inputs):
        """Additive emission scores, reshaped to (batch, input_len, input_len)."""
        input_shape = K.shape(inputs)
        batch_size, input_len = input_shape[0], input_shape[1]

        # h_{t, t'} = \tanh(x_t^T W_t + x_{t'}^T W_x + b_h)
        q, k = K.dot(inputs, self.Wt), K.dot(inputs, self.Wx)
        q = K.tile(K.expand_dims(q, 2), K.stack([1, 1, input_len, 1]))
        k = K.tile(K.expand_dims(k, 1), K.stack([1, input_len, 1, 1]))
        if self.use_additive_bias:
            h = K.tanh(q + k + self.bh)
        else:
            h = K.tanh(q + k)

        # e_{t, t'} = W_a h_{t, t'} + b_a
        if self.use_attention_bias:
            e = K.reshape(K.dot(h, self.Wa) + self.ba, (batch_size, input_len, input_len))
        else:
            e = K.reshape(K.dot(h, self.Wa), (batch_size, input_len, input_len))
        return e

    def _call_multiplicative_emission(self, inputs):
        """Multiplicative emission scores between all pairs of positions."""
        # e_{t, t'} = x_t^T W_a x_{t'} + b_a
        e = K.batch_dot(K.dot(inputs, self.Wa), K.permute_dimensions(inputs, (0, 2, 1)))
        if self.use_attention_bias:
            e = e + self.ba
        return e

    def compute_output_shape(self, input_shape):
        """Output shape, plus the attention shape when return_attention is set."""
        if isinstance(input_shape, list):
            input_shape, pos_shape = input_shape
            output_shape = (input_shape[0], pos_shape[1], input_shape[2])
        else:
            output_shape = input_shape
        if self.return_attention:
            attention_shape = (input_shape[0], output_shape[1], input_shape[1])
            return [output_shape, attention_shape]
        return output_shape

    def compute_mask(self, inputs, mask=None):
        """Propagate the mask; the attention output itself carries no mask."""
        if isinstance(inputs, list):
            # No mask is supplied when none of the inputs carries one.
            if mask is not None:
                mask = mask[1]
        if self.return_attention:
            return [mask, None]
        return mask

    def _attention_regularizer(self, attention):
        """Penalty on the squared difference between A A^T and the identity, averaged over the batch."""
        batch_size = K.cast(K.shape(attention)[0], K.floatx())
        input_len = K.shape(attention)[-1]
        return self.attention_regularizer_weight * K.sum(K.square(K.batch_dot(
            attention,
            K.permute_dimensions(attention, (0, 2, 1))) - tf.eye(input_len))) / batch_size

    @staticmethod
    def get_custom_objects():
        """Mapping to pass as ``custom_objects`` when loading a saved model."""
        return {'SeqSelfAttention': SeqSelfAttention}

In [55]:
def create_model(seq_len, unique_notes, dropout=0.3, output_emb=100, rnn_unit=128, dense_unit=64):
    """Build the next-note prediction network.

    Layout: embedding -> 2 x (bidirectional GRU -> multiplicative
    self-attention -> dropout) -> bidirectional GRU -> dropout -> dense ->
    LeakyReLU -> softmax over the vocabulary.

    Parameters
    ==========
    seq_len : int
      Number of timesteps in one input sequence
    unique_notes : int
      Vocabulary size; embedding and output use ``unique_notes + 1`` entries
    dropout : float
      Rate of every dropout layer
    output_emb : int
      Embedding dimension
    rnn_unit : int
      Units of each GRU direction
    dense_unit : int
      Units of the hidden dense layer

    Returns
    =======
    An uncompiled ``tf.keras.Model`` named 'generate_scores_rnn'.
    """
    layers = tf.keras.layers
    vocabulary_size = unique_notes + 1

    inputs = layers.Input(shape=(seq_len,))
    hidden = layers.Embedding(
        input_dim=vocabulary_size,
        output_dim=output_emb,
        input_length=seq_len,
    )(inputs)

    # Two identical recurrent + attention stages; the attention weights that
    # the attention layer also returns are not used further.
    for _ in range(2):
        hidden = layers.Bidirectional(layers.GRU(rnn_unit, return_sequences=True))(hidden)
        hidden, _attention_weights = SeqSelfAttention(
            return_attention=True,
            attention_activation='sigmoid',
            attention_type=SeqSelfAttention.ATTENTION_TYPE_MUL,
            attention_width=50,
            kernel_regularizer=tf.keras.regularizers.l2(1e-4),
            bias_regularizer=tf.keras.regularizers.l1(1e-4),
            attention_regularizer_weight=1e-4,
        )(hidden)
        hidden = layers.Dropout(dropout)(hidden)

    # The last GRU collapses the sequence into a single vector.
    hidden = layers.Bidirectional(layers.GRU(rnn_unit))(hidden)
    hidden = layers.Dropout(dropout)(hidden)
    hidden = layers.Dense(dense_unit)(hidden)
    hidden = layers.LeakyReLU()(hidden)
    outputs = layers.Dense(vocabulary_size, activation="softmax")(hidden)

    return tf.keras.Model(inputs=inputs, outputs=outputs, name='generate_scores_rnn')

# Instantiate the network for the configured sequence length and vocabulary size.
model = create_model(seq_len, unique_notes)

In [56]:
# Print the layer-by-layer overview (output shapes and parameter counts).
model.summary()

Model: "generate_scores_rnn"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_3 (InputLayer)         [(None, 50)]              0         
_________________________________________________________________
embedding_2 (Embedding)      (None, 50, 100)           3863500   
_________________________________________________________________
bidirectional_4 (Bidirection (None, 50, 256)           176640    
_________________________________________________________________
seq_self_attention_3 (SeqSel [(None, 50, 256), (None,  65537     
_________________________________________________________________
dropout_3 (Dropout)          (None, 50, 256)           0         
_________________________________________________________________
bidirectional_5 (Bidirection (None, 50, 256)           296448    
_________________________________________________________________
seq_self_attention_4 (SeqSel [(None, 50, 256), 

# Train

In [57]:
import os
from tensorflow.keras.optimizers import Nadam
from tensorflow.keras.losses import sparse_categorical_crossentropy
from random import shuffle, seed
# Nadam with its default settings.
optimizer = Nadam()

# Track optimizer and model together so that both are written by checkpoint.save().
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 model=model)
checkpoint_dir = './training_checkpoints'
# Path prefix handed to checkpoint.save() by the training loop.
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
# Sparse variant: targets are integer token ids, predictions are softmax vectors.
loss_fn = sparse_categorical_crossentropy

In [58]:
def generate_batch_song(list_all_midi, batch_music=16, start_index=0, fs=30, seq_len=50, use_tqdm=False):
    """
    Build the network inputs and targets for one batch of songs.

    Parameters
    ==========
    list_all_midi : list
      List of midi files
    batch_music : int
      Number of songs in one batch
    start_index : int
      Index in list_all_midi at which the batch starts
    fs : int
      Sampling frequency of the columns, i.e. each column is spaced apart
        by ``1./fs`` seconds.
    seq_len : int
      Length of each input sequence fed to the neural network
    use_tqdm : bool
      Whether to show a tqdm progress bar while loading the files

    Returns
    =======
    Tuple (inputs, targets): the per-song results of
    generate_input_and_target concatenated in song order.

    """
    assert len(list_all_midi) >= batch_music

    piano_rolls = generate_dict_time_notes(list_all_midi, batch_music, start_index, fs, use_tqdm=use_tqdm)
    songs = process_notes_in_song(piano_rolls, seq_len)

    collected_list_input = []
    collected_list_target = []
    for song in songs:
        song_inputs, song_targets = generate_input_and_target(song, seq_len)
        collected_list_input.extend(song_inputs)
        collected_list_target.extend(song_targets)
    return collected_list_input, collected_list_target


In [59]:
class TrainModel:
    """Custom training loop that streams songs from disk in batches.

    Songs are loaded and tokenised ``batch_song`` at a time; the resulting
    sequences are shuffled and fed to the network in mini-batches of
    ``batch_nnet_size``. A checkpoint is written after every epoch.
    """

    def __init__(self, epochs, note_tokenizer, sample_midi, frame_per_second,
                 batch_nnet_size, batch_song, optimizer, checkpoint, loss_fn,
                 checkpoint_prefix, total_songs, model):
        """Store the training configuration.

        Parameters
        ==========
        epochs : int
          Number of passes over the songs
        note_tokenizer : NoteTokenizer
          Fitted tokenizer used to turn note strings into ids
        sample_midi : list
          MIDI file paths; shuffled in place at the start of every epoch
        frame_per_second : int
          Piano-roll sampling frequency
        batch_nnet_size : int
          Number of sequences per gradient step
        batch_song : int
          Number of songs loaded per outer step
        optimizer, checkpoint, loss_fn, model :
          Keras/TensorFlow objects used by the loop
        checkpoint_prefix : str
          Path prefix passed to ``checkpoint.save``
        total_songs : int
          Number of songs to iterate over in each epoch
        """
        self.epochs = epochs
        self.note_tokenizer = note_tokenizer
        self.sample_midi = sample_midi
        self.frame_per_second = frame_per_second
        self.batch_nnet_size = batch_nnet_size
        self.batch_song = batch_song
        self.optimizer = optimizer
        self.checkpoint = checkpoint
        self.loss_fn = loss_fn
        self.checkpoint_prefix = checkpoint_prefix
        self.total_songs = total_songs
        self.model = model

    def train(self):
        """Run the full training loop and save a checkpoint after each epoch."""
        for epoch in tqdm_notebook(range(self.epochs), desc='epochs'):
            # Each epoch visits the songs in a new order (shuffles the list in place).
            shuffle(self.sample_midi)
            loss_total = 0   # sum of per-sample losses accumulated over the epoch
            steps_nnet = 0

            # Iterate over all songs, self.batch_song at a time.
            for i in tqdm_notebook(range(0, self.total_songs, self.batch_song), desc='MUSIC'):
                # NOTE: seq_len is read from the notebook-level variable, not
                # from the constructor arguments.
                inputs_nnet_large, outputs_nnet_large = generate_batch_song(
                    self.sample_midi, self.batch_song, start_index=i,
                    fs=self.frame_per_second, seq_len=seq_len, use_tqdm=False)
                inputs_nnet_large = np.array(
                    self.note_tokenizer.transform(inputs_nnet_large), dtype=np.int32)
                outputs_nnet_large = np.array(
                    self.note_tokenizer.transform(outputs_nnet_large), dtype=np.int32)

                index_shuffled = np.arange(start=0, stop=len(inputs_nnet_large))
                np.random.shuffle(index_shuffled)

                for nnet_steps in tqdm_notebook(range(0, len(index_shuffled), self.batch_nnet_size)):
                    steps_nnet += 1
                    current_index = index_shuffled[nnet_steps:nnet_steps + self.batch_nnet_size]
                    inputs_nnet = inputs_nnet_large[current_index]
                    outputs_nnet = outputs_nnet_large[current_index]

                    # Only full mini-batches are trained on, so the traced
                    # tf.function always sees the same input shape; a short
                    # batch can only be the last one of this song batch.
                    if len(inputs_nnet) < self.batch_nnet_size:
                        break
                    loss = self.train_step(inputs_nnet, outputs_nnet)
                    loss_total += tf.math.reduce_sum(loss)
                    if steps_nnet % 20 == 0:
                        print("epochs {} | Steps {} | total loss : {}".format(epoch + 1, steps_nnet, loss_total))

            # Save through the checkpoint handed to the constructor rather than
            # whatever notebook-level `checkpoint` happens to exist.
            self.checkpoint.save(file_prefix=self.checkpoint_prefix)

    @tf.function
    def train_step(self, inputs, targets):
        """Run one gradient step and return the per-sample losses."""
        with tf.GradientTape() as tape:
            # training=True switches the Dropout layers on; without it the
            # model runs in inference mode inside a custom loop.
            prediction = self.model(inputs, training=True)
            loss = self.loss_fn(targets, prediction)
        gradients = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
        return loss

In [60]:

# NOTE(review): these hyperparameters repeat the values already assigned in the
# earlier configuration cell; keep the two cells in sync when changing them.
seq_len = 50
EPOCHS = 4
BATCH_SONG = 16
BATCH_NNET_SIZE = 96
TOTAL_SONGS = len(sample_midi)
FRAME_PER_SECOND = 5

# Wire tokenizer, data, optimizer, checkpoint and model into the training loop.
train_class = TrainModel(EPOCHS, note_tokenizer, sample_midi, FRAME_PER_SECOND,
                  BATCH_NNET_SIZE, BATCH_SONG, optimizer, checkpoint, loss_fn,
                  checkpoint_prefix, TOTAL_SONGS, model)

# Runs all epochs; prints the running total loss every 20 gradient steps.
train_class.train()

HBox(children=(IntProgress(value=0, description='epochs', max=4, style=ProgressStyle(description_width='initia…

HBox(children=(IntProgress(value=0, description='MUSIC', max=7, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, max=393), HTML(value='')))

epochs 1 | Steps 20 | total loss : 17157.984375
epochs 1 | Steps 40 | total loss : 31662.171875
epochs 1 | Steps 60 | total loss : 45972.4453125
epochs 1 | Steps 80 | total loss : 59814.45703125
epochs 1 | Steps 100 | total loss : 73809.8046875
epochs 1 | Steps 120 | total loss : 87635.96875
epochs 1 | Steps 140 | total loss : 101202.6171875
epochs 1 | Steps 160 | total loss : 115059.640625
epochs 1 | Steps 180 | total loss : 129019.671875
epochs 1 | Steps 200 | total loss : 142440.53125
epochs 1 | Steps 220 | total loss : 156046.078125
epochs 1 | Steps 240 | total loss : 169455.734375
epochs 1 | Steps 260 | total loss : 182954.921875
epochs 1 | Steps 280 | total loss : 196510.28125
epochs 1 | Steps 300 | total loss : 209766.28125
epochs 1 | Steps 320 | total loss : 223068.15625
epochs 1 | Steps 340 | total loss : 236473.46875
epochs 1 | Steps 360 | total loss : 249827.625
epochs 1 | Steps 380 | total loss : 263156.65625


HBox(children=(IntProgress(value=0, max=443), HTML(value='')))

epochs 1 | Steps 400 | total loss : 276420.1875
epochs 1 | Steps 420 | total loss : 290842.21875
epochs 1 | Steps 440 | total loss : 305567.15625
epochs 1 | Steps 460 | total loss : 319905.25
epochs 1 | Steps 480 | total loss : 334348.3125
epochs 1 | Steps 500 | total loss : 348391.21875
epochs 1 | Steps 520 | total loss : 362631.40625
epochs 1 | Steps 540 | total loss : 376467.5
epochs 1 | Steps 560 | total loss : 390691.25
epochs 1 | Steps 580 | total loss : 404364.9375
epochs 1 | Steps 600 | total loss : 418220.40625
epochs 1 | Steps 620 | total loss : 432352.6875
epochs 1 | Steps 640 | total loss : 446562.90625
epochs 1 | Steps 660 | total loss : 460464.875
epochs 1 | Steps 680 | total loss : 474237.65625
epochs 1 | Steps 700 | total loss : 488149.1875
epochs 1 | Steps 720 | total loss : 501603.875
epochs 1 | Steps 740 | total loss : 515013.75
epochs 1 | Steps 760 | total loss : 528713.8125
epochs 1 | Steps 780 | total loss : 542122.875
epochs 1 | Steps 800 | total loss : 555782.31

HBox(children=(IntProgress(value=0, max=434), HTML(value='')))

epochs 1 | Steps 840 | total loss : 582693.3125
epochs 1 | Steps 860 | total loss : 596157.0
epochs 1 | Steps 880 | total loss : 610297.1875
epochs 1 | Steps 900 | total loss : 623898.1875
epochs 1 | Steps 920 | total loss : 637533.0625
epochs 1 | Steps 940 | total loss : 651363.4375
epochs 1 | Steps 960 | total loss : 664966.25
epochs 1 | Steps 980 | total loss : 678622.0625
epochs 1 | Steps 1000 | total loss : 692433.75
epochs 1 | Steps 1020 | total loss : 705864.75
epochs 1 | Steps 1040 | total loss : 719628.4375
epochs 1 | Steps 1060 | total loss : 732888.125
epochs 1 | Steps 1080 | total loss : 746511.3125
epochs 1 | Steps 1100 | total loss : 760081.6875
epochs 1 | Steps 1120 | total loss : 773345.0625
epochs 1 | Steps 1140 | total loss : 786700.125
epochs 1 | Steps 1160 | total loss : 799877.625
epochs 1 | Steps 1180 | total loss : 813268.875
epochs 1 | Steps 1200 | total loss : 826533.75
epochs 1 | Steps 1220 | total loss : 839966.625
epochs 1 | Steps 1240 | total loss : 853501.

HBox(children=(IntProgress(value=0, max=435), HTML(value='')))

epochs 1 | Steps 1280 | total loss : 879714.1875
epochs 1 | Steps 1300 | total loss : 893823.0
epochs 1 | Steps 1320 | total loss : 907884.9375
epochs 1 | Steps 1340 | total loss : 921689.6875
epochs 1 | Steps 1360 | total loss : 935495.125
epochs 1 | Steps 1380 | total loss : 949262.1875
epochs 1 | Steps 1400 | total loss : 962988.0
epochs 1 | Steps 1420 | total loss : 976510.0625
epochs 1 | Steps 1440 | total loss : 989929.8125
epochs 1 | Steps 1460 | total loss : 1003135.1875
epochs 1 | Steps 1480 | total loss : 1016555.625
epochs 1 | Steps 1500 | total loss : 1030124.6875
epochs 1 | Steps 1520 | total loss : 1043679.625
epochs 1 | Steps 1540 | total loss : 1057096.375
epochs 1 | Steps 1560 | total loss : 1070516.875
epochs 1 | Steps 1580 | total loss : 1083991.0
epochs 1 | Steps 1600 | total loss : 1097279.375
epochs 1 | Steps 1620 | total loss : 1110452.75
epochs 1 | Steps 1640 | total loss : 1123674.5
epochs 1 | Steps 1660 | total loss : 1136883.5
epochs 1 | Steps 1680 | total lo

HBox(children=(IntProgress(value=0, max=568), HTML(value='')))

epochs 1 | Steps 1720 | total loss : 1176410.25
epochs 1 | Steps 1740 | total loss : 1190505.375
epochs 1 | Steps 1760 | total loss : 1204234.875
epochs 1 | Steps 1780 | total loss : 1218282.5
epochs 1 | Steps 1800 | total loss : 1232235.0
epochs 1 | Steps 1820 | total loss : 1246104.25
epochs 1 | Steps 1840 | total loss : 1260128.0
epochs 1 | Steps 1860 | total loss : 1274053.0
epochs 1 | Steps 1880 | total loss : 1287934.125
epochs 1 | Steps 1900 | total loss : 1301383.875
epochs 1 | Steps 1920 | total loss : 1315177.5
epochs 1 | Steps 1940 | total loss : 1328631.5
epochs 1 | Steps 1960 | total loss : 1342075.625
epochs 1 | Steps 1980 | total loss : 1355619.25
epochs 1 | Steps 2000 | total loss : 1369081.25
epochs 1 | Steps 2020 | total loss : 1382309.75
epochs 1 | Steps 2040 | total loss : 1395944.125
epochs 1 | Steps 2060 | total loss : 1409303.375
epochs 1 | Steps 2080 | total loss : 1422659.375
epochs 1 | Steps 2100 | total loss : 1436384.5
epochs 1 | Steps 2120 | total loss : 14

HBox(children=(IntProgress(value=0, max=543), HTML(value='')))

epochs 1 | Steps 2280 | total loss : 1555956.5
epochs 1 | Steps 2300 | total loss : 1570239.875
epochs 1 | Steps 2320 | total loss : 1584312.375
epochs 1 | Steps 2340 | total loss : 1598676.375
epochs 1 | Steps 2360 | total loss : 1612772.25
epochs 1 | Steps 2380 | total loss : 1626994.25
epochs 1 | Steps 2400 | total loss : 1641106.875
epochs 1 | Steps 2420 | total loss : 1655162.25
epochs 1 | Steps 2440 | total loss : 1669152.5
epochs 1 | Steps 2460 | total loss : 1683231.875
epochs 1 | Steps 2480 | total loss : 1696998.625
epochs 1 | Steps 2500 | total loss : 1711203.25
epochs 1 | Steps 2520 | total loss : 1725626.875
epochs 1 | Steps 2540 | total loss : 1739752.125
epochs 1 | Steps 2560 | total loss : 1753882.875
epochs 1 | Steps 2580 | total loss : 1768034.125
epochs 1 | Steps 2600 | total loss : 1781917.5
epochs 1 | Steps 2620 | total loss : 1795824.25
epochs 1 | Steps 2640 | total loss : 1809946.75
epochs 1 | Steps 2660 | total loss : 1823935.125
epochs 1 | Steps 2680 | total lo

HBox(children=(IntProgress(value=0, max=165), HTML(value='')))

epochs 1 | Steps 2820 | total loss : 1934571.5
epochs 1 | Steps 2840 | total loss : 1949001.375
epochs 1 | Steps 2860 | total loss : 1962704.625
epochs 1 | Steps 2880 | total loss : 1976750.875
epochs 1 | Steps 2900 | total loss : 1990473.125
epochs 1 | Steps 2920 | total loss : 2003869.25
epochs 1 | Steps 2940 | total loss : 2016915.0
epochs 1 | Steps 2960 | total loss : 2030108.0
epochs 1 | Steps 2980 | total loss : 2043192.5


HBox(children=(IntProgress(value=0, description='MUSIC', max=7, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, max=475), HTML(value='')))

epochs 2 | Steps 20 | total loss : 13062.263671875
epochs 2 | Steps 40 | total loss : 26436.466796875
epochs 2 | Steps 60 | total loss : 39192.734375
epochs 2 | Steps 80 | total loss : 52431.89453125
epochs 2 | Steps 100 | total loss : 65761.1953125
epochs 2 | Steps 120 | total loss : 78785.1953125
epochs 2 | Steps 140 | total loss : 91904.453125
epochs 2 | Steps 160 | total loss : 104518.5625
epochs 2 | Steps 180 | total loss : 117062.4296875
epochs 2 | Steps 200 | total loss : 129896.015625
epochs 2 | Steps 220 | total loss : 142643.796875
epochs 2 | Steps 240 | total loss : 155630.78125
epochs 2 | Steps 260 | total loss : 168632.6875
epochs 2 | Steps 280 | total loss : 181198.46875
epochs 2 | Steps 300 | total loss : 193824.984375
epochs 2 | Steps 320 | total loss : 206417.828125
epochs 2 | Steps 340 | total loss : 218895.328125
epochs 2 | Steps 360 | total loss : 231660.390625
epochs 2 | Steps 380 | total loss : 244497.703125
epochs 2 | Steps 400 | total loss : 257145.640625
epochs

HBox(children=(IntProgress(value=0, max=463), HTML(value='')))

epochs 2 | Steps 480 | total loss : 307256.4375
epochs 2 | Steps 500 | total loss : 320651.625
epochs 2 | Steps 520 | total loss : 334182.96875
epochs 2 | Steps 540 | total loss : 347567.34375
epochs 2 | Steps 560 | total loss : 361228.96875
epochs 2 | Steps 580 | total loss : 374584.5
epochs 2 | Steps 600 | total loss : 388033.65625
epochs 2 | Steps 620 | total loss : 401548.3125
epochs 2 | Steps 640 | total loss : 414772.4375
epochs 2 | Steps 660 | total loss : 428022.875
epochs 2 | Steps 680 | total loss : 441254.1875
epochs 2 | Steps 700 | total loss : 454648.65625
epochs 2 | Steps 720 | total loss : 467848.1875
epochs 2 | Steps 740 | total loss : 481267.125
epochs 2 | Steps 760 | total loss : 494197.4375
epochs 2 | Steps 780 | total loss : 507310.875
epochs 2 | Steps 800 | total loss : 520493.09375
epochs 2 | Steps 820 | total loss : 533435.5625
epochs 2 | Steps 840 | total loss : 546703.0625
epochs 2 | Steps 860 | total loss : 560008.25
epochs 2 | Steps 880 | total loss : 572928.

HBox(children=(IntProgress(value=0, max=385), HTML(value='')))

epochs 2 | Steps 940 | total loss : 611783.5
epochs 2 | Steps 960 | total loss : 624793.6875
epochs 2 | Steps 980 | total loss : 637651.9375
epochs 2 | Steps 1000 | total loss : 650164.125
epochs 2 | Steps 1020 | total loss : 662496.375
epochs 2 | Steps 1040 | total loss : 674894.0625
epochs 2 | Steps 1060 | total loss : 687164.9375
epochs 2 | Steps 1080 | total loss : 699488.5625
epochs 2 | Steps 1100 | total loss : 711811.5625
epochs 2 | Steps 1120 | total loss : 724153.4375
epochs 2 | Steps 1140 | total loss : 736606.75
epochs 2 | Steps 1160 | total loss : 748696.125
epochs 2 | Steps 1180 | total loss : 761313.125
epochs 2 | Steps 1200 | total loss : 773607.6875
epochs 2 | Steps 1220 | total loss : 786025.625
epochs 2 | Steps 1240 | total loss : 798169.0625
epochs 2 | Steps 1260 | total loss : 810349.6875
epochs 2 | Steps 1280 | total loss : 822395.6875
epochs 2 | Steps 1300 | total loss : 834859.125
epochs 2 | Steps 1320 | total loss : 847003.0


HBox(children=(IntProgress(value=0, max=482), HTML(value='')))

epochs 2 | Steps 1340 | total loss : 859654.5
epochs 2 | Steps 1360 | total loss : 873089.0
epochs 2 | Steps 1380 | total loss : 886071.875
epochs 2 | Steps 1400 | total loss : 899296.1875
epochs 2 | Steps 1420 | total loss : 912411.125
epochs 2 | Steps 1440 | total loss : 925431.1875
epochs 2 | Steps 1460 | total loss : 938428.5625
epochs 2 | Steps 1480 | total loss : 951314.0625
epochs 2 | Steps 1500 | total loss : 964510.0
epochs 2 | Steps 1520 | total loss : 977467.25
epochs 2 | Steps 1540 | total loss : 990595.25
epochs 2 | Steps 1560 | total loss : 1003685.4375
epochs 2 | Steps 1580 | total loss : 1016478.1875
epochs 2 | Steps 1600 | total loss : 1029423.0
epochs 2 | Steps 1620 | total loss : 1042358.375
epochs 2 | Steps 1640 | total loss : 1054970.125
epochs 2 | Steps 1660 | total loss : 1067746.0
epochs 2 | Steps 1680 | total loss : 1080616.75
epochs 2 | Steps 1700 | total loss : 1093381.625
epochs 2 | Steps 1720 | total loss : 1106430.875
epochs 2 | Steps 1740 | total loss : 1

HBox(children=(IntProgress(value=0, max=475), HTML(value='')))

epochs 2 | Steps 1820 | total loss : 1170743.5
epochs 2 | Steps 1840 | total loss : 1184459.75
epochs 2 | Steps 1860 | total loss : 1198257.875
epochs 2 | Steps 1880 | total loss : 1211875.5
epochs 2 | Steps 1900 | total loss : 1225244.0
epochs 2 | Steps 1920 | total loss : 1238676.25
epochs 2 | Steps 1940 | total loss : 1252000.0
epochs 2 | Steps 1960 | total loss : 1265349.5
epochs 2 | Steps 1980 | total loss : 1278759.75
epochs 2 | Steps 2000 | total loss : 1292564.75
epochs 2 | Steps 2020 | total loss : 1305972.75
epochs 2 | Steps 2040 | total loss : 1319163.125
epochs 2 | Steps 2060 | total loss : 1332309.75
epochs 2 | Steps 2080 | total loss : 1345412.375
epochs 2 | Steps 2100 | total loss : 1358605.0
epochs 2 | Steps 2120 | total loss : 1371707.625
epochs 2 | Steps 2140 | total loss : 1384912.5
epochs 2 | Steps 2160 | total loss : 1398352.375
epochs 2 | Steps 2180 | total loss : 1411605.625
epochs 2 | Steps 2200 | total loss : 1424646.25
epochs 2 | Steps 2220 | total loss : 1437

HBox(children=(IntProgress(value=0, max=535), HTML(value='')))

epochs 2 | Steps 2300 | total loss : 1489904.375
epochs 2 | Steps 2320 | total loss : 1503155.25
epochs 2 | Steps 2340 | total loss : 1516582.0
epochs 2 | Steps 2360 | total loss : 1529781.625
epochs 2 | Steps 2380 | total loss : 1543209.5
epochs 2 | Steps 2400 | total loss : 1556054.875
epochs 2 | Steps 2420 | total loss : 1568976.25
epochs 2 | Steps 2440 | total loss : 1581970.25
epochs 2 | Steps 2460 | total loss : 1594988.75
epochs 2 | Steps 2480 | total loss : 1608042.0
epochs 2 | Steps 2500 | total loss : 1621005.25
epochs 2 | Steps 2520 | total loss : 1634240.375
epochs 2 | Steps 2540 | total loss : 1647275.0
epochs 2 | Steps 2560 | total loss : 1660270.75
epochs 2 | Steps 2580 | total loss : 1673382.125
epochs 2 | Steps 2600 | total loss : 1686315.625
epochs 2 | Steps 2620 | total loss : 1698688.625
epochs 2 | Steps 2640 | total loss : 1711351.375
epochs 2 | Steps 2660 | total loss : 1724494.0
epochs 2 | Steps 2680 | total loss : 1737759.5
epochs 2 | Steps 2700 | total loss : 1

HBox(children=(IntProgress(value=0, max=165), HTML(value='')))

epochs 2 | Steps 2820 | total loss : 1825941.125
epochs 2 | Steps 2840 | total loss : 1838038.875
epochs 2 | Steps 2860 | total loss : 1849496.375
epochs 2 | Steps 2880 | total loss : 1860905.625
epochs 2 | Steps 2900 | total loss : 1872191.25
epochs 2 | Steps 2920 | total loss : 1883488.75
epochs 2 | Steps 2940 | total loss : 1894726.75
epochs 2 | Steps 2960 | total loss : 1906096.625


HBox(children=(IntProgress(value=0, description='MUSIC', max=7, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, max=442), HTML(value='')))

epochs 3 | Steps 20 | total loss : 13181.439453125
epochs 3 | Steps 40 | total loss : 25949.244140625
epochs 3 | Steps 60 | total loss : 38477.26171875
epochs 3 | Steps 80 | total loss : 51145.21875
epochs 3 | Steps 100 | total loss : 63706.859375
epochs 3 | Steps 120 | total loss : 75801.2109375
epochs 3 | Steps 140 | total loss : 88054.734375
epochs 3 | Steps 160 | total loss : 100502.3125
epochs 3 | Steps 180 | total loss : 112922.109375
epochs 3 | Steps 200 | total loss : 125426.6640625
epochs 3 | Steps 220 | total loss : 137552.1875
epochs 3 | Steps 240 | total loss : 149657.640625
epochs 3 | Steps 260 | total loss : 162293.84375
epochs 3 | Steps 280 | total loss : 174538.734375
epochs 3 | Steps 300 | total loss : 186826.359375
epochs 3 | Steps 320 | total loss : 198982.71875
epochs 3 | Steps 340 | total loss : 211496.546875
epochs 3 | Steps 360 | total loss : 223793.578125
epochs 3 | Steps 380 | total loss : 235909.140625
epochs 3 | Steps 400 | total loss : 248089.328125
epochs 3

HBox(children=(IntProgress(value=0, max=532), HTML(value='')))

epochs 3 | Steps 460 | total loss : 284770.21875
epochs 3 | Steps 480 | total loss : 297286.40625
epochs 3 | Steps 500 | total loss : 309942.15625
epochs 3 | Steps 520 | total loss : 322207.09375
epochs 3 | Steps 540 | total loss : 334886.625
epochs 3 | Steps 560 | total loss : 346910.15625
epochs 3 | Steps 580 | total loss : 359363.625
epochs 3 | Steps 600 | total loss : 371823.25
epochs 3 | Steps 620 | total loss : 384290.8125
epochs 3 | Steps 640 | total loss : 396558.9375
epochs 3 | Steps 660 | total loss : 409078.1875
epochs 3 | Steps 680 | total loss : 421373.0625
epochs 3 | Steps 700 | total loss : 433507.125
epochs 3 | Steps 720 | total loss : 446233.84375
epochs 3 | Steps 740 | total loss : 458623.28125
epochs 3 | Steps 760 | total loss : 470888.78125
epochs 3 | Steps 780 | total loss : 483156.21875
epochs 3 | Steps 800 | total loss : 495462.34375
epochs 3 | Steps 820 | total loss : 507647.75
epochs 3 | Steps 840 | total loss : 519981.625
epochs 3 | Steps 860 | total loss : 53

HBox(children=(IntProgress(value=0, max=485), HTML(value='')))

epochs 3 | Steps 980 | total loss : 604990.1875
epochs 3 | Steps 1000 | total loss : 617856.25
epochs 3 | Steps 1020 | total loss : 630575.3125
epochs 3 | Steps 1040 | total loss : 643100.0625
epochs 3 | Steps 1060 | total loss : 655711.5625
epochs 3 | Steps 1080 | total loss : 667934.875
epochs 3 | Steps 1100 | total loss : 680637.875
epochs 3 | Steps 1120 | total loss : 692947.5625
epochs 3 | Steps 1140 | total loss : 705530.75
epochs 3 | Steps 1160 | total loss : 717841.3125
epochs 3 | Steps 1180 | total loss : 730076.875
epochs 3 | Steps 1200 | total loss : 742486.8125
epochs 3 | Steps 1220 | total loss : 755008.5
epochs 3 | Steps 1240 | total loss : 767343.25
epochs 3 | Steps 1260 | total loss : 779988.6875
epochs 3 | Steps 1280 | total loss : 792372.625
epochs 3 | Steps 1300 | total loss : 804411.3125
epochs 3 | Steps 1320 | total loss : 816740.25
epochs 3 | Steps 1340 | total loss : 828733.8125
epochs 3 | Steps 1360 | total loss : 840865.1875
epochs 3 | Steps 1380 | total loss :

HBox(children=(IntProgress(value=0, max=437), HTML(value='')))

epochs 3 | Steps 1460 | total loss : 900924.8125
epochs 3 | Steps 1480 | total loss : 913838.4375
epochs 3 | Steps 1500 | total loss : 926360.6875
epochs 3 | Steps 1520 | total loss : 938768.125
epochs 3 | Steps 1540 | total loss : 951350.25
epochs 3 | Steps 1560 | total loss : 963286.0625
epochs 3 | Steps 1580 | total loss : 975456.9375
epochs 3 | Steps 1600 | total loss : 987710.625
epochs 3 | Steps 1620 | total loss : 999762.8125
epochs 3 | Steps 1640 | total loss : 1012122.1875
epochs 3 | Steps 1660 | total loss : 1024355.0
epochs 3 | Steps 1680 | total loss : 1036374.875
epochs 3 | Steps 1700 | total loss : 1048480.4375
epochs 3 | Steps 1720 | total loss : 1060341.875
epochs 3 | Steps 1740 | total loss : 1072473.25
epochs 3 | Steps 1760 | total loss : 1084360.75
epochs 3 | Steps 1780 | total loss : 1096072.125
epochs 3 | Steps 1800 | total loss : 1107928.5
epochs 3 | Steps 1820 | total loss : 1119897.75
epochs 3 | Steps 1840 | total loss : 1131718.375
epochs 3 | Steps 1860 | total

HBox(children=(IntProgress(value=0, max=490), HTML(value='')))

epochs 3 | Steps 1900 | total loss : 1166749.75
epochs 3 | Steps 1920 | total loss : 1179536.125
epochs 3 | Steps 1940 | total loss : 1191897.125
epochs 3 | Steps 1960 | total loss : 1204382.5
epochs 3 | Steps 1980 | total loss : 1216612.25
epochs 3 | Steps 2000 | total loss : 1228885.625
epochs 3 | Steps 2020 | total loss : 1241233.5
epochs 3 | Steps 2040 | total loss : 1253501.625
epochs 3 | Steps 2060 | total loss : 1265585.875
epochs 3 | Steps 2080 | total loss : 1277817.0
epochs 3 | Steps 2100 | total loss : 1290214.625
epochs 3 | Steps 2120 | total loss : 1302406.0
epochs 3 | Steps 2140 | total loss : 1314602.625
epochs 3 | Steps 2160 | total loss : 1326801.875
epochs 3 | Steps 2180 | total loss : 1338828.875
epochs 3 | Steps 2200 | total loss : 1350887.125
epochs 3 | Steps 2220 | total loss : 1363139.875
epochs 3 | Steps 2240 | total loss : 1375314.375
epochs 3 | Steps 2260 | total loss : 1387236.375
epochs 3 | Steps 2280 | total loss : 1399450.875
epochs 3 | Steps 2300 | total 

HBox(children=(IntProgress(value=0, max=472), HTML(value='')))

epochs 3 | Steps 2400 | total loss : 1471432.625
epochs 3 | Steps 2420 | total loss : 1484008.5
epochs 3 | Steps 2440 | total loss : 1496804.0
epochs 3 | Steps 2460 | total loss : 1509380.5
epochs 3 | Steps 2480 | total loss : 1521795.0
epochs 3 | Steps 2500 | total loss : 1534488.125
epochs 3 | Steps 2520 | total loss : 1546713.25
epochs 3 | Steps 2540 | total loss : 1559072.75
epochs 3 | Steps 2560 | total loss : 1571229.25
epochs 3 | Steps 2580 | total loss : 1583367.5
epochs 3 | Steps 2600 | total loss : 1595516.25
epochs 3 | Steps 2620 | total loss : 1607692.875
epochs 3 | Steps 2640 | total loss : 1619534.625
epochs 3 | Steps 2660 | total loss : 1631524.5
epochs 3 | Steps 2680 | total loss : 1643528.25
epochs 3 | Steps 2700 | total loss : 1655337.0
epochs 3 | Steps 2720 | total loss : 1667097.875
epochs 3 | Steps 2740 | total loss : 1679208.375
epochs 3 | Steps 2760 | total loss : 1690990.375
epochs 3 | Steps 2780 | total loss : 1702850.75
epochs 3 | Steps 2800 | total loss : 171

HBox(children=(IntProgress(value=0, max=121), HTML(value='')))

epochs 3 | Steps 2860 | total loss : 1749551.5
epochs 3 | Steps 2880 | total loss : 1762460.125
epochs 3 | Steps 2900 | total loss : 1774956.625
epochs 3 | Steps 2920 | total loss : 1787382.125
epochs 3 | Steps 2940 | total loss : 1799282.625
epochs 3 | Steps 2960 | total loss : 1811083.25


HBox(children=(IntProgress(value=0, description='MUSIC', max=7, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, max=574), HTML(value='')))

epochs 4 | Steps 20 | total loss : 12433.9296875
epochs 4 | Steps 40 | total loss : 24493.49609375
epochs 4 | Steps 60 | total loss : 36138.12890625
epochs 4 | Steps 80 | total loss : 47778.7265625
epochs 4 | Steps 100 | total loss : 59683.421875
epochs 4 | Steps 120 | total loss : 71748.40625
epochs 4 | Steps 140 | total loss : 83300.390625
epochs 4 | Steps 160 | total loss : 94971.1875
epochs 4 | Steps 180 | total loss : 106908.1171875
epochs 4 | Steps 200 | total loss : 118763.5703125
epochs 4 | Steps 220 | total loss : 130389.2109375
epochs 4 | Steps 240 | total loss : 142034.03125
epochs 4 | Steps 260 | total loss : 153526.515625
epochs 4 | Steps 280 | total loss : 165144.296875
epochs 4 | Steps 300 | total loss : 176800.328125
epochs 4 | Steps 320 | total loss : 188320.703125
epochs 4 | Steps 340 | total loss : 199939.015625
epochs 4 | Steps 360 | total loss : 211617.3125
epochs 4 | Steps 380 | total loss : 223304.28125
epochs 4 | Steps 400 | total loss : 235077.375
epochs 4 | St

HBox(children=(IntProgress(value=0, max=396), HTML(value='')))

epochs 4 | Steps 580 | total loss : 337054.9375
epochs 4 | Steps 600 | total loss : 349279.75
epochs 4 | Steps 620 | total loss : 360887.71875
epochs 4 | Steps 640 | total loss : 372306.4375
epochs 4 | Steps 660 | total loss : 383884.1875
epochs 4 | Steps 680 | total loss : 395023.09375
epochs 4 | Steps 700 | total loss : 406229.1875
epochs 4 | Steps 720 | total loss : 417472.65625
epochs 4 | Steps 740 | total loss : 428399.21875
epochs 4 | Steps 760 | total loss : 439494.9375
epochs 4 | Steps 780 | total loss : 450509.3125
epochs 4 | Steps 800 | total loss : 461696.53125
epochs 4 | Steps 820 | total loss : 472653.875
epochs 4 | Steps 840 | total loss : 483839.8125
epochs 4 | Steps 860 | total loss : 495073.375
epochs 4 | Steps 880 | total loss : 506146.90625
epochs 4 | Steps 900 | total loss : 517150.78125
epochs 4 | Steps 920 | total loss : 527771.0
epochs 4 | Steps 940 | total loss : 538716.5
epochs 4 | Steps 960 | total loss : 549607.4375


HBox(children=(IntProgress(value=0, max=561), HTML(value='')))

epochs 4 | Steps 980 | total loss : 560656.3125
epochs 4 | Steps 1000 | total loss : 572649.1875
epochs 4 | Steps 1020 | total loss : 584572.6875
epochs 4 | Steps 1040 | total loss : 596567.25
epochs 4 | Steps 1060 | total loss : 608396.5625
epochs 4 | Steps 1080 | total loss : 619879.0625
epochs 4 | Steps 1100 | total loss : 631510.8125
epochs 4 | Steps 1120 | total loss : 643290.0
epochs 4 | Steps 1140 | total loss : 655113.625
epochs 4 | Steps 1160 | total loss : 666616.75
epochs 4 | Steps 1180 | total loss : 678169.0625
epochs 4 | Steps 1200 | total loss : 689514.375
epochs 4 | Steps 1220 | total loss : 701055.0
epochs 4 | Steps 1240 | total loss : 712156.5
epochs 4 | Steps 1260 | total loss : 723838.0
epochs 4 | Steps 1280 | total loss : 735109.8125
epochs 4 | Steps 1300 | total loss : 746526.875
epochs 4 | Steps 1320 | total loss : 758212.875
epochs 4 | Steps 1340 | total loss : 769833.25
epochs 4 | Steps 1360 | total loss : 781425.875
epochs 4 | Steps 1380 | total loss : 792802.

HBox(children=(IntProgress(value=0, max=438), HTML(value='')))

epochs 4 | Steps 1540 | total loss : 883765.625
epochs 4 | Steps 1560 | total loss : 896241.4375
epochs 4 | Steps 1580 | total loss : 908111.5625
epochs 4 | Steps 1600 | total loss : 920245.6875
epochs 4 | Steps 1620 | total loss : 931852.9375
epochs 4 | Steps 1640 | total loss : 943570.5
epochs 4 | Steps 1660 | total loss : 955286.0
epochs 4 | Steps 1680 | total loss : 966797.125
epochs 4 | Steps 1700 | total loss : 978408.9375
epochs 4 | Steps 1720 | total loss : 989985.5625
epochs 4 | Steps 1740 | total loss : 1001496.0625
epochs 4 | Steps 1760 | total loss : 1012729.125
epochs 4 | Steps 1780 | total loss : 1023968.125
epochs 4 | Steps 1800 | total loss : 1035105.875
epochs 4 | Steps 1820 | total loss : 1046542.9375
epochs 4 | Steps 1840 | total loss : 1057760.25
epochs 4 | Steps 1860 | total loss : 1068894.875
epochs 4 | Steps 1880 | total loss : 1080307.0
epochs 4 | Steps 1900 | total loss : 1091609.5
epochs 4 | Steps 1920 | total loss : 1102662.5
epochs 4 | Steps 1940 | total los

HBox(children=(IntProgress(value=0, max=563), HTML(value='')))

epochs 4 | Steps 1980 | total loss : 1136877.375
epochs 4 | Steps 2000 | total loss : 1149213.125
epochs 4 | Steps 2020 | total loss : 1161579.75
epochs 4 | Steps 2040 | total loss : 1173903.375
epochs 4 | Steps 2060 | total loss : 1186014.5
epochs 4 | Steps 2080 | total loss : 1198044.875
epochs 4 | Steps 2100 | total loss : 1209995.25
epochs 4 | Steps 2120 | total loss : 1221700.875
epochs 4 | Steps 2140 | total loss : 1233553.25
epochs 4 | Steps 2160 | total loss : 1245360.0
epochs 4 | Steps 2180 | total loss : 1257384.125
epochs 4 | Steps 2200 | total loss : 1269449.875
epochs 4 | Steps 2220 | total loss : 1281523.25
epochs 4 | Steps 2240 | total loss : 1292881.5
epochs 4 | Steps 2260 | total loss : 1304578.375
epochs 4 | Steps 2280 | total loss : 1316106.375
epochs 4 | Steps 2300 | total loss : 1327644.375
epochs 4 | Steps 2320 | total loss : 1339247.5
epochs 4 | Steps 2340 | total loss : 1350815.125
epochs 4 | Steps 2360 | total loss : 1362438.625
epochs 4 | Steps 2380 | total lo

HBox(children=(IntProgress(value=0, max=331), HTML(value='')))

epochs 4 | Steps 2540 | total loss : 1464390.25
epochs 4 | Steps 2560 | total loss : 1476085.875
epochs 4 | Steps 2580 | total loss : 1487888.875
epochs 4 | Steps 2600 | total loss : 1499180.125
epochs 4 | Steps 2620 | total loss : 1510702.125
epochs 4 | Steps 2640 | total loss : 1522059.875
epochs 4 | Steps 2660 | total loss : 1533228.25
epochs 4 | Steps 2680 | total loss : 1544543.625
epochs 4 | Steps 2700 | total loss : 1555718.25
epochs 4 | Steps 2720 | total loss : 1567049.625
epochs 4 | Steps 2740 | total loss : 1577951.25
epochs 4 | Steps 2760 | total loss : 1588939.875
epochs 4 | Steps 2780 | total loss : 1599866.0
epochs 4 | Steps 2800 | total loss : 1610876.375
epochs 4 | Steps 2820 | total loss : 1621855.375
epochs 4 | Steps 2840 | total loss : 1632906.125
epochs 4 | Steps 2860 | total loss : 1643868.75


HBox(children=(IntProgress(value=0, max=118), HTML(value='')))

epochs 4 | Steps 2880 | total loss : 1655025.0
epochs 4 | Steps 2900 | total loss : 1666277.375
epochs 4 | Steps 2920 | total loss : 1676952.5
epochs 4 | Steps 2940 | total loss : 1687610.75
epochs 4 | Steps 2960 | total loss : 1697723.5
epochs 4 | Steps 2980 | total loss : 1708068.625


In [61]:
# Persist the trained network and the fitted tokenizer so both can be
# reloaded later without retraining.
model.save('model_ep4.h5')
# Write the pickle through a context manager so the file is flushed and
# closed even if serialization raises part-way through.
with open("tokenizer.p", "wb") as tokenizer_file:
    pickle.dump(note_tokenizer, tokenizer_file)

In [62]:
# Reload the saved network. The custom layer classes are passed through
# `custom_objects` so Keras can resolve them while deserializing the file.
model = tf.keras.models.load_model('model_ep4.h5', custom_objects=SeqSelfAttention.get_custom_objects())
# NOTE: pickle.load can execute arbitrary code embedded in the file -- only
# load a tokenizer.p that was produced by a trusted run of this notebook.
# The context manager guarantees the file handle is closed after reading.
with open("tokenizer.p", "rb") as tokenizer_file:
    note_tokenizer = pickle.load(tokenizer_file)



In [63]:
def generate_from_random(unique_notes, seq_len=50):
  """ Build a seed sequence made of uniformly random note indices.

  Parameters
  ==========
  unique_notes : int
    Exclusive upper bound of the sampled indices; every value drawn lies
    in the range [0, unique_notes).
  seq_len : int
    How many indices to draw.

  Returns
  =======
  A plain Python list containing `seq_len` integers.

  """
  random_indices = np.random.randint(low=0, high=unique_notes, size=seq_len)
  return random_indices.tolist()
    
def generate_from_one_note(note_tokenizer, new_notes='35', seq_len=50):
  """ Build a seed sequence that is all 'e' tokens followed by one real note.

  Parameters
  ==========
  note_tokenizer : object
    Tokenizer exposing a `notes_to_index` mapping that contains the 'e'
    token and `new_notes`.
  new_notes : str
    Note string placed in the last position of the seed.
  seq_len : int
    Total length of the returned seed. Defaults to 50, which reproduces the
    previous hard-coded layout of 49 'e' tokens plus one note.

  Returns
  =======
  A list of `seq_len` note indices: `seq_len - 1` copies of the 'e' index
  followed by the index of `new_notes`.

  Raises
  ======
  ValueError if `seq_len` is smaller than 1.
  KeyError if 'e' or `new_notes` is missing from the tokenizer.

  """
  if seq_len < 1:
    raise ValueError("seq_len must be at least 1, got {}".format(seq_len))
  # Look the padding token up once instead of once per position.
  empty_index = note_tokenizer.notes_to_index['e']
  generate = [empty_index] * (seq_len - 1)
  generate.append(note_tokenizer.notes_to_index[new_notes])
  return generate

In [72]:
def piano_roll_to_pretty_midi(piano_roll, fs=100, program=0):
    '''Turn a piano-roll matrix into a single-instrument PrettyMIDI object.

    Parameters
    ----------
    piano_roll : np.ndarray, shape=(128,frames), dtype=int
        Piano roll of one instrument; rows are pitches, columns are frames.
    fs : int
        Sampling frequency of the columns, i.e. consecutive columns are
        ``1./fs`` seconds apart.
    program : int
        The program number of the instrument.

    Returns
    -------
    midi_object : pretty_midi.PrettyMIDI
        A pretty_midi.PrettyMIDI instance holding one instrument whose notes
        describe the piano roll.
    '''
    # Unpacking also enforces that the input is two-dimensional.
    n_pitches, _n_frames = piano_roll.shape
    midi = pretty_midi.PrettyMIDI()
    track = pretty_midi.Instrument(program=program)

    # One silent column on each side, so notes that are already on in the
    # first frame or still on in the last frame produce on/off transitions.
    padded = np.pad(piano_roll, [(0, 0), (1, 1)], 'constant')

    # Every nonzero frame-to-frame difference is a candidate on/off event.
    # The transpose makes the events come out ordered by frame.
    change_frames, change_pitches = np.nonzero(np.diff(padded).T)

    # Per pitch: velocity of the note currently sounding (0 = silent) and
    # the time at which it started.
    active_velocity = np.zeros(n_pitches, dtype=int)
    onset_seconds = np.zeros(n_pitches)

    for frame, pitch in zip(change_frames, change_pitches):
        # frame + 1 compensates for the padding column added above
        new_velocity = padded[pitch, frame + 1]
        seconds = frame / fs
        if new_velocity > 0:
            if active_velocity[pitch] != 0:
                # A velocity change while the note is already sounding is ignored.
                continue
            onset_seconds[pitch] = seconds
            active_velocity[pitch] = new_velocity
            continue
        # Velocity dropped to zero: close the note that was sounding.
        track.notes.append(
            pretty_midi.Note(
                velocity=active_velocity[pitch],
                pitch=pitch,
                start=onset_seconds[pitch],
                end=seconds))
        active_velocity[pitch] = 0

    midi.instruments.append(track)
    return midi


In [73]:
from numpy.random import choice
def generate_notes(generate, model, unique_notes, max_generated=1000, seq_len=50):
  """ Extend a seed sequence by sampling notes from the model one at a time.

  Parameters
  ==========
  generate : list
    Seed sequence of note indices. It is extended IN PLACE and also
    returned. Step i feeds `generate[i:i+seq_len]` to the model, so the
    window follows the tail of the sequence when the seed holds exactly
    `seq_len` entries.
  model : object
    Object with a `predict(batch)` method returning, for a batch of one
    window, an array whose first row is a distribution over
    `unique_notes + 1` indices.
  unique_notes : int
    Number of distinct notes; samples are drawn from `unique_notes + 1`
    possible indices.
  max_generated : int
    Number of notes to append.
  seq_len : int
    Length of the input window passed to the model.

  Returns
  =======
  The same list object as `generate`, now `max_generated` entries longer.

  Raises
  ======
  ValueError if the model returns a distribution whose mass is zero,
  negative or not finite.

  """
  # tqdm is used for progress display only; fall back to plain iteration
  # when it is not installed. tqdm.auto replaces the deprecated
  # tqdm_notebook entry point.
  try:
    from tqdm.auto import tqdm as progress_bar
  except ImportError:
    def progress_bar(iterable, **_kwargs):
      return iterable

  for i in progress_bar(range(max_generated), desc='genrt'):
    # Slice first, then convert: only seq_len values are turned into an
    # array per step instead of the whole, ever-growing sequence.
    test_input = np.array([generate[i:i + seq_len]])
    predicted_note = model.predict(test_input)
    # The model output may be float32 and not sum to 1 within the tolerance
    # numpy's choice() enforces, so renormalise in float64 before sampling.
    probabilities = np.asarray(predicted_note[0], dtype=np.float64)
    total = probabilities.sum()
    if not np.isfinite(total) or total <= 0:
      raise ValueError("model returned an invalid probability distribution (sum={})".format(total))
    probabilities = probabilities / total
    random_note_pred = choice(unique_notes + 1, 1, replace=False, p=probabilities)
    generate.append(random_note_pred[0])
  return generate

In [74]:
def write_midi_file_from_generated(generate, midi_file_name = "result.mid", start_index=49, fs=8, max_generated=1000):
  """ Render a generated sequence of note indices to a MIDI file.

  Parameters
  ==========
  generate : list
    Sequence of note indices (seed followed by generated notes).
  midi_file_name : str
    Path of the MIDI file to write.
  start_index : int
    Position in `generate` where rendering starts; earlier entries are
    skipped.
  fs : int
    Frames per second used when converting the piano roll to MIDI.
  max_generated : int
    Minimum number of generated steps the piano roll is sized for. The roll
    is widened automatically when the rendered part is longer.

  Notes
  =====
  Reads the module-level `note_tokenizer` to decode indices. Each decoded
  note is either 'e' (nothing is played on that step) or a comma-separated
  list of MIDI pitch numbers that sound together on that step.

  """
  # Decode only the part that is rendered: seed entries before start_index
  # are never played, so an index there that the tokenizer cannot decode
  # (NOTE(review): a random seed may contain such an index -- confirm) must
  # not abort the export.
  note_string = [note_tokenizer.index_to_notes[ind_note] for ind_note in generate[start_index:]]
  # Never allocate fewer columns than there are steps to render; extra
  # trailing silent columns do not change the resulting notes.
  n_frames = max(max_generated + 1, len(note_string))
  array_piano_roll = np.zeros((128, n_frames), dtype=np.int16)
  for index, note in enumerate(note_string):
    if note == 'e':
      continue
    for j in note.split(','):
      array_piano_roll[int(j), index] = 1
  generate_to_midi = piano_roll_to_pretty_midi(array_piano_roll, fs=fs)
  # pretty_midi cannot estimate a tempo from fewer than two notes and raises
  # ValueError; report that instead of failing before the file is written.
  try:
    print("Tempo {}".format(generate_to_midi.estimate_tempo()))
  except ValueError:
    print("Tempo unavailable: too few notes were generated")
  # The roll only stores 0/1, so give every note an audible velocity.
  for note in generate_to_midi.instruments[0].notes:
    note.velocity = 100
  generate_to_midi.write(midi_file_name)

In [101]:

# Pick a random numeric suffix for the output MIDI file names.
from random import randint, seed

# seed(9)  # uncomment to make the suffix reproducible
value = randint(0, 100)
print(value)
name = 'random{}.mid'.format(value)
name1 = 'one_note{}.mid'.format(value)
print(name)

0
random0.mid


In [100]:
# Generate a piece starting from a purely random seed window and export it.
seq_len = 50
max_generate = 200
unique_notes = note_tokenizer.unique_word
generate = generate_from_random(unique_notes, seq_len=seq_len)
generate = generate_notes(generate, model, unique_notes, max_generated=max_generate, seq_len=seq_len)
# Rendering starts at the last seed entry, followed by everything generated.
write_midi_file_from_generated(
    generate,
    midi_file_name=name,
    start_index=seq_len - 1,
    fs=7,
    max_generated=max_generate,
)

HBox(children=(IntProgress(value=0, description='genrt', max=200, style=ProgressStyle(description_width='initi…

Tempo 209.4117647058822


In [103]:
# Generate a piece from a seed that ends in the single note '72' and export it.
seq_len = 50
max_generate = 300
unique_notes = note_tokenizer.unique_word
generate = generate_from_one_note(note_tokenizer, new_notes='72')
generate = generate_notes(generate, model, unique_notes, max_generated=max_generate, seq_len=seq_len)
# Rendering starts at the last seed entry, followed by everything generated.
write_midi_file_from_generated(
    generate,
    midi_file_name=name1,
    start_index=seq_len - 1,
    fs=8,
    max_generated=max_generate,
)

HBox(children=(IntProgress(value=0, description='genrt', max=300, style=ProgressStyle(description_width='initi…

Tempo 236.46315789473675
