In [415]:
! pip install pretty_midi

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [428]:
import numpy as np
import scipy.signal

import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import LSTM, Dropout, Dense, BatchNormalization
import keras.backend as K

In [429]:
import glob
import urllib.request
import pretty_midi

import matplotlib.pyplot as plt
import time
import ipywidgets as widgets
from IPython.display import clear_output
import IPython

In [430]:
n_x = 8
nb_units = 256
num_layers = 3*2
batch_size = 50

sequence_len = 20

step = 0.3
sleep = 7
output = widgets.Output()

mu = []
actions = []
observations = []

In [431]:
DIR = './'

midiFile_l = ['cs1-2all.mid', 'cs5-1pre.mid', 'cs4-1pre.mid', 'cs3-5bou.mid', 'cs1-4sar.mid', 'cs2-5men.mid', 'cs3-3cou.mid', 'cs2-3cou.mid', 'cs1-6gig.mid', 'cs6-4sar.mid', 'cs4-5bou.mid', 'cs4-3cou.mid', 'cs5-3cou.mid', 'cs6-5gav.mid', 'cs6-6gig.mid', 'cs6-2all.mid', 'cs2-1pre.mid', 'cs3-1pre.mid', 'cs3-6gig.mid', 'cs2-6gig.mid', 'cs2-4sar.mid', 'cs3-4sar.mid', 'cs1-5men.mid', 'cs1-3cou.mid', 'cs6-1pre.mid', 'cs2-2all.mid', 'cs3-2all.mid', 'cs1-1pre.mid', 'cs5-2all.mid', 'cs4-2all.mid', 'cs5-5gav.mid', 'cs4-6gig.mid', 'cs5-6gig.mid', 'cs5-4sar.mid', 'cs4-4sar.mid', 'cs6-3cou.mid']
for midiFile in midiFile_l:
  #if os.path.isfile(DIR + midiFile) is None:
  urllib.request.urlretrieve ("http://www.jsbach.net/midi/" + midiFile, DIR + midiFile)
nbExample = len(midiFile_l)

midiFile_l = glob.glob(DIR + 'cs*.mid')

X_list = []

for midiFile in midiFile_l:
    # read the MIDI file
    midi_data = pretty_midi.PrettyMIDI(midiFile)
    note_l = [note.pitch for note in midi_data.instruments[0].notes]
    
    X_list.append(note_l)

X_train_list = []
y_train_list = []


for X in X_list:
  for i in range(len(X)-sequence_len):
    array = np.array(X[i:1+i+sequence_len])
    array -= min(array)
    if max(array)<=7:
      x = np.zeros((sequence_len, n_x))
      y = np.zeros((n_x,))
      for e in range(sequence_len):
        x[e, array[e]] = 1
      
      y[array[sequence_len]] = 1
      X_train_list.append(x)
      y_train_list.append(y)

songs = np.asarray(X_train_list)
y_train = np.asarray(y_train_list)

In [432]:
class Reward(Model):
  def __init__(self, songs=songs):
    super(Reward, self).__init__()
    self.songs = songs
    self.build_model()
  
  def build_model(self):
    self.lstm1 = LSTM(nb_units, input_shape=(sequence_len, n_x), return_sequences=True)
    self.batch_norm1 = BatchNormalization()
    self.dropout1 = Dropout(0.2)
    self.lstm2 = LSTM(nb_units, return_sequences=True)
    self.batch_norm2 = BatchNormalization()
    self.dropout2 = Dropout(0.2)
    self.lstm3 = LSTM(nb_units, return_sequences=True)
    self.batch_norm3 = BatchNormalization()
    self.dropout3 = Dropout(0.2)
    self.dense = Dense(n_x, activation='sigmoid')
 
  def call_rewards(self, inputs):
    observations, actions = inputs

    x = self.lstm1(observations)
    x = self.batch_norm1(x)
    x = self.dropout1(x)
    x = self.lstm2(x)
    x = self.batch_norm2(x)
    x = self.dropout2(x)
    x = self.lstm3(x)
    x = self.dense(x)

    x = self.lstm1(tf.math.softmax(x))
    x = self.batch_norm1(x)
    x = self.lstm2(x)
    x = self.batch_norm2(x)
    x = self.lstm3(x)
    x = self.batch_norm3(x)
    x = self.dense(x)

    x = tf.reshape(x, (-1, n_x))
    actions = tf.reshape(actions, (-1,))
    x = tf.gather(x, actions, axis=-1, batch_dims=1)
    x = tf.reshape(x, (-1, sequence_len))

    return x
  
  def reset(self, n):
    return np.array(self.songs[np.random.randint(len(self.songs),  size=(n,))])
  
  def call(self, inputs):
    observations, actions = inputs

    actions1 = actions[:,0,:]
    actions2 = actions[:,1,:]

    x1 = self.call_rewards([observations, actions1])
    x2 = self.call_rewards([observations, actions2])

    x1 = tf.math.reduce_sum(x1, axis=-1)[:,None]
    x2 = tf.math.reduce_sum(x2, axis=-1)[:,None]

    x = tf.concat([x1, x2], axis=1)

    return tf.nn.softmax(x)

In [433]:
discount = 0.95

class Policy(Model):
  def __init__(self):
    super(Policy, self).__init__()
    self.build_model()
    
  def build_model(self):
    self.lstm1 = LSTM(nb_units, input_shape=(sequence_len, n_x), return_sequences=True)
    self.batch_norm1 = BatchNormalization()
    self.dropout1 = Dropout(0.2)
    self.lstm2 = LSTM(nb_units, return_sequences=True)
    self.batch_norm2 = BatchNormalization()
    self.dropout2 = Dropout(0.2)
    self.lstm3 = LSTM(nb_units, return_sequences=True)
    self.batch_norm3 = BatchNormalization()
    self.dropout3 = Dropout(0.2)
    self.dense = Dense(n_x, activation='softmax')

  def call(self, inputs):
    x = self.lstm1(inputs)
    x = self.batch_norm1(x)
    x = self.lstm2(x)
    x = self.batch_norm2(x)
    x = self.lstm3(x)
    x = self.batch_norm3(x)
    x = self.dense(x)

    x = self.lstm1(x)
    x = self.batch_norm1(x)
    x = self.dropout1(x)
    x = self.lstm2(x)
    x = self.batch_norm2(x)
    x = self.dropout2(x)
    x = self.lstm3(x)
    x = self.batch_norm3(x)
    x = self.dropout3(x)
    x = self.dense(x)

    return x
  
  def get_queries(self):
    observations = re.reset(1)
    probs = self.call(observations)
    probs = tf.reshape(probs, (-1,n_x))
    actions = tf.random.categorical(probs, 2)
    actions = tf.reshape(actions, (2,-1))

    return actions, observations

  def hf_train(self):
    optimizer = tf.keras.optimizers.Adam()

    observations = re.reset(batch_size)

    with tf.GradientTape() as tape:
      probs = self.call(observations)
      probs = tf.reshape(probs, (-1,n_x))

      actions = tf.random.categorical(probs, 1)
      actions = tf.reshape(actions, (-1,))

      rewards = re.call_rewards([observations, tf.reshape(actions, (-1, sequence_len))])
      # rewards = scipy.signal.lfilter([1.], [1, -discount], rewards[::-1])[::-1]

      ls = tf.gather(probs, actions, axis=-1, batch_dims=1)
      ls = tf.reshape(ls, (batch_size, sequence_len))
      ls = -tf.math.log(ls)
      
      ls = tf.multiply(ls, rewards)
      ls = tf.reduce_mean(ls)
    
    variables = self.trainable_variables
    gradients = tape.gradient(ls, variables)
    optimizer.apply_gradients(zip(gradients, variables))
      
    return ls.numpy()

In [448]:
pi = Policy()

for i in range(200):
  print(f"Training policy: Epoch {i+1}", end=" Loss ")
  print(pi.hf_train())

Training policy: Epoch 1 Loss 3.3054552
Training policy: Epoch 2 Loss 3.3185968
Training policy: Epoch 3 Loss 3.1965456
Training policy: Epoch 4 Loss 3.3145874
Training policy: Epoch 5 Loss 3.26048
Training policy: Epoch 6 Loss 3.3008497
Training policy: Epoch 7 Loss 3.2542713
Training policy: Epoch 8 Loss 3.2228897
Training policy: Epoch 9 Loss 3.312581
Training policy: Epoch 10 Loss 3.4033296
Training policy: Epoch 11 Loss 3.4935884
Training policy: Epoch 12 Loss 3.2624037
Training policy: Epoch 13 Loss 3.432679
Training policy: Epoch 14 Loss 3.4637399
Training policy: Epoch 15 Loss 3.2250695
Training policy: Epoch 16 Loss 3.2347455
Training policy: Epoch 17 Loss 3.1763306
Training policy: Epoch 18 Loss 3.150961
Training policy: Epoch 19 Loss 3.1337123
Training policy: Epoch 20 Loss 3.269529
Training policy: Epoch 21 Loss 3.1617286
Training policy: Epoch 22 Loss 3.3474689
Training policy: Epoch 23 Loss 3.4256868
Training policy: Epoch 24 Loss 3.2736552
Training policy: Epoch 25 Loss 

In [435]:
class HumanFeedback():
  def get_midi_data(self, actions_):
    data = []
    for e in [0,1]:
      note_l = actions_[e]
      new_midi_data = pretty_midi.PrettyMIDI()
      cello_program = pretty_midi.instrument_name_to_program('Cello')
      cello = pretty_midi.Instrument(program=cello_program)
      t = 0
      for note_number in note_l:
          myNote = pretty_midi.Note(velocity=100, pitch=int(note_number)+70, start=t, end=t+step)
          cello.notes.append(myNote)
          t += step
      new_midi_data.instruments.append(cello)
      data.append(new_midi_data.synthesize(fs=44100))
    return data
      
  def display(self, actions_, observations_):
    global mu, actions, observations

    data = self.get_midi_data(actions_)

    display_1 = IPython.display.Audio(data[0], autoplay=True, rate=44100)
    display_2 = IPython.display.Audio(data[1], autoplay=True, rate=44100)

    IPython.display.display(display_1)
    time.sleep(sleep)
    IPython.display.display(display_2)

    m = input()

    if m == '1':
      mu.append([1,0])
    if m == '2':
      mu.append([0,1])
    
    actions.append(np.asarray(actions_).tolist())
    observations.append(np.asarray(observations_).tolist())

    clear_output()

In [441]:
hf = HumanFeedback()

pi = Policy()

re = Reward()
re.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

In [451]:
good_actions = np.argmax(songs, axis=-1)
bad_actions = np.random.randint(n_x, size=good_actions.shape)

actions = np.zeros((good_actions.shape[0],2,good_actions.shape[1]))

mu = np.eye(2, dtype=int)[np.random.choice(2, good_actions.shape[0])]
mu = mu.tolist()

for i in range(good_actions.shape[0]):
  if mu[i][0] == 1:
    actions[i,0,:] = good_actions[i]
    actions[i,1,:] = bad_actions[i]
  else:
    actions[i,1,:] = bad_actions[i]
    actions[i,0,:] = good_actions[i]

actions = actions.astype(int).tolist()

observations = re.reset(good_actions.shape[0])
observations = observations.reshape((good_actions.shape[0], -1, n_x)).tolist()

In [452]:
# training using synthetic labels
re.fit([np.asarray(observations),np.asarray(actions)], np.asarray(mu), epochs=5, batch_size=24)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x7f3d910a1070>

In [444]:
for i in range(20):
  print(f"Training policy: Epoch {i+1}", end=" Loss ")
  print(pi.hf_train())

Training policy: Epoch 1 Loss 8.708883
Training policy: Epoch 2 Loss 8.840943
Training policy: Epoch 3 Loss 8.651164
Training policy: Epoch 4 Loss 8.678113
Training policy: Epoch 5 Loss 8.662377
Training policy: Epoch 6 Loss 8.742814
Training policy: Epoch 7 Loss 8.582878
Training policy: Epoch 8 Loss 8.75246
Training policy: Epoch 9 Loss 8.660283
Training policy: Epoch 10 Loss 8.62584
Training policy: Epoch 11 Loss 8.8733015
Training policy: Epoch 12 Loss 8.754227
Training policy: Epoch 13 Loss 8.676853
Training policy: Epoch 14 Loss 8.824286
Training policy: Epoch 15 Loss 8.890201
Training policy: Epoch 16 Loss 8.595139
Training policy: Epoch 17 Loss 8.725646
Training policy: Epoch 18 Loss 8.693422
Training policy: Epoch 19 Loss 8.779534
Training policy: Epoch 20 Loss 8.727784


In [449]:
for i in range(5):
  a, o = pi.get_queries()
  hf.display(a, o)

KeyboardInterrupt: ignored