<a href="https://colab.research.google.com/github/rsaxby/NoteRNN/blob/master/noteRNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# RNN for Music Generation

Import PyTorch

In [0]:
# http://pytorch.org/
# Colab bootstrap: build the wheel platform tag for this interpreter, detect
# the CUDA runtime version, then pip-install the matching torch 0.4.1 wheel.
# NOTE(review): `wheel.pep425tags` is presumably unavailable in recent `wheel`
# releases, and recent Colab images ship torch preinstalled - confirm this
# cell is still needed before running it.
from os.path import exists
from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
# formatted as "<impl><version>-<abi>"
platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())
# IPython shell capture: rewrite the libcudart version suffix into a "cuXY" tag
cuda_output = !ldconfig -p|grep cudart.so|sed -e 's/.*\.\([0-9]*\)\.\([0-9]*\)$/cu\1\2/'
# fall back to the CPU wheel when no NVIDIA device node is present
accelerator = cuda_output[0] if exists('/dev/nvidia0') else 'cpu'
!pip install -q http://download.pytorch.org/whl/{accelerator}/torch-0.4.1-{platform}-linux_x86_64.whl torchvision
import torch

Mount Google Drive to get data


In [0]:
# Load the Drive helper and mount Google Drive at /content/drive
# (the data directory configured later in this notebook lives under this mount point)
from google.colab import drive
drive.mount('/content/drive')

Import libraries and check if we can train on GPU

In [0]:
# import libraries
import numpy as np
import pandas as pd 
import torch
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
from torch import nn
import torch.nn.functional as F
!pip install music21
from music21 import *
!pip install pygame
import pygame
from google.colab import files
#configure.run()
import glob
from torch import optim
from torchvision import datasets

# detect once whether a CUDA device can be used; the rest of the notebook
# reads this module-level flag
train_on_gpu = torch.cuda.is_available()

print('CUDA is available!  Training on GPU ...' if train_on_gpu
      else 'CUDA is not available.  Training on CPU ...')

# Prepare the Dataset
The `Dataset` class creates a dataset from the MIDI files in the directory specified by `data_dir`. We extract note, chord, and rest objects from each MIDI stream and store them sequentially in a list (to be fed to the network for training). Each object is also stored in the dictionary `music21_objects`, which we later use to turn our generated notes back into MIDI events.
Within the `create_dataset` method, we:


1.   Extract and store all note/chord/rest objects
2.   Create a dictionary to encode our unique objects
3.   Encode our list of notes to be used for training

In [0]:
class Dataset():
    """Build an integer-encoded note sequence from a directory of MIDI files.

    After ``create_dataset()`` has run:
      * ``notes``           - every extracted note/chord/rest name, in order
      * ``music21_objects`` - name -> music21 object, used to rebuild MIDI events
      * ``unique_notes``    - sorted list of the distinct names in ``notes``
      * ``pitch2int``       - name -> integer class id
      * ``num_classes``     - size of the vocabulary
      * ``encoded_notes``   - ``notes`` mapped through ``pitch2int`` (numpy array)
    """

    def __init__(self, data_dir):
        self.notes = []  # all extracted notes from songs (string form)
        self.music21_objects = {}  # name -> music21 note/chord/rest object
        self.data_dir = data_dir  # directory where midi files are stored/saved
        self.unique_notes = None
        self.pitch2int = None
        self.num_classes = None
        self.encoded_notes = np.array([], dtype=np.int64)
        # pitch_to_int() fills unique_notes / pitch2int / num_classes itself and
        # returns None, so its result must not be assigned to an attribute.
        self.pitch_to_int()

    def create_dataset(self):
        """Extract notes from every ``*.MID`` file in ``data_dir`` and encode them.

        Files that cannot be parsed are reported and skipped so that a single
        bad file does not abort the whole run.
        """
        import os  # local import: only needed for the display name

        # use the instance's directory, not a module-level `data_dir`
        midi_files = glob.glob(self.data_dir + "/*.MID")
        num_files = len(midi_files)
        for count, file in enumerate(midi_files, start=1):
            # str.strip() removes a *set of characters*, not a prefix, so the
            # file name is taken with basename() instead
            print("\n{}/{} Processing {}...\n".format(count, num_files, os.path.basename(file)))
            try:
                midi = self.open_midi(file)
                self.extract_notes(midi)
            except Exception as err:  # best effort: keep going, but say why
                print("Could not process: {}".format(file))
                print("Reason: {}".format(err))
        self.pitch_to_int()
        self.encode_notes()

    def extract_notes(self, midi):
        """Append the notes, chords and rests of a parsed MIDI stream to ``notes``.

        Notes are stored by ``fullName``, chords by ``commonName`` and rests as
        the string produced by ``encode_rest``. The music21 object of each
        element is remembered in ``music21_objects`` (rests under their plain
        ``fullName``).
        """
        # get a list of all the notes and chords in the file
        parts = instrument.partitionByInstrument(midi)

        if parts:  # the file has instrument parts: use the first one
            notes_to_parse = list(parts.parts[0].recurse())
        else:  # the file has notes in a flat structure
            notes_to_parse = midi.flat.notes

        for el in notes_to_parse:
            if isinstance(el, note.Rest):
                self.notes.append(self.encode_rest(el.fullName))
                self.music21_objects[el.fullName] = el
            elif isinstance(el, note.Note):
                self.notes.append(el.fullName)
                self.music21_objects[el.fullName] = el
            elif isinstance(el, chord.Chord):
                self.notes.append(el.commonName)
                self.music21_objects[el.commonName] = el

    def encode_rest(self, rest):
        """Return ``rest`` prefixed with the (up to) 10 previous entries of ``notes``.

        The history and the rest name are separated by ``" $"``. When nothing
        has been extracted yet the rest name is returned unchanged.
        """
        if not self.notes:
            return rest
        # slicing with -10 already yields the whole list when it is shorter
        history = "".join(self.notes[-10:])
        return history + " $" + rest

    def open_midi(self, file_name):
        """Parse a MIDI file with music21 and return the resulting stream."""
        return converter.parse(file_name)

    def list_instruments(self, midi):
        """Print the part name of every part in a parsed MIDI stream."""
        partStream = midi.parts.stream()
        print("Instruments on MIDI file:")
        for part in partStream:
            print(part.partName)

    def analyze_song(self, midi):
        """Print the first time signature and the estimated key of a stream."""
        timeSig = midi.getTimeSignatures()[0]
        musicAnalysis = midi.analyze('key')
        print("Time signature: {0}/{1}".format(timeSig.beatCount, timeSig.denominator))
        print("Expected music key: {0}".format(musicAnalysis))
        print("Music key confidence: {0}".format(musicAnalysis.correlationCoefficient))

    def play_song(self, midi_file):
        """Play a stream through music21's realtime ``StreamPlayer``."""
        print("Playing MIDI...")
        song = midi.realtime.StreamPlayer(midi_file)
        song.play()

    def save_song(self, midi, file_name):
        """Write a stream as a MIDI file named ``file_name`` inside ``data_dir``."""
        import os  # local import: join works with or without a trailing slash

        midi.write('mid', fp=os.path.join(self.data_dir, file_name))

    def create_note(self, note):
        """Return the MIDI number of the pitch named by the string ``note``."""
        return pitch.Pitch(note).midi

    def pitch_to_int(self):
        """Rebuild ``unique_notes``, ``pitch2int`` and ``num_classes`` from ``notes``."""
        self.unique_notes = sorted(set(self.notes))
        # the note name is the key, its rank in sorted order is the class id
        self.pitch2int = {name: number for number, name in enumerate(self.unique_notes)}
        self.num_classes = len(self.pitch2int)

    def encode_notes(self):
        """Store ``notes`` as an int64 array of class ids in ``encoded_notes``."""
        # int64 explicitly: the array is later used as the class-index targets
        # of the loss, and the platform default integer is not always 64-bit
        self.encoded_notes = np.array(
            [self.pitch2int[name] for name in self.notes], dtype=np.int64)

    def save_notes(self):
        """Save the extracted note names to ``notes.csv`` (no header, no index)."""
        df = pd.DataFrame(self.notes)
        df.to_csv("notes.csv", header=None, index=None)


Below are functions we will use during training:


*   `to_categorical` : one-hot encodes a batch 
*   `get_batches` : creates a batch of notes to be fed to the network (currently predicting 3 notes at a time)
*   `save_model` : saves the model during training



In [0]:
# helper functions 

def to_categorical(x, num_classes):
    """One-hot encode integer class ids.

    Each id in ``x`` is replaced by the matching row of a float32 identity
    matrix, so the result has the shape of ``x`` plus a trailing axis of
    length ``num_classes``.
    """
    identity = np.identity(num_classes, dtype=np.float32)
    return identity[x]
  

def get_batches(arr, batch_size, seq_length):  
  
  '''          
  Generator yielding (x, y) pairs of shape (batch_size, seq_length).

  arr: 1-D array of encoded notes
  batch_size: number of sequences (rows) in a batch
  seq_length: number of notes in a sequence

  Trailing notes that do not fill a complete batch are dropped. Nothing is
  yielded when arr holds fewer than batch_size * seq_length notes.
  '''
  batch_size_total = batch_size * seq_length
  # total number of full batches we can make
  n_batches = len(arr)//batch_size_total

  # keep only enough notes to make full batches
  arr = arr[:n_batches * batch_size_total]
  # reshape so we have as many rows/sequences as batch_size
  arr = arr.reshape((batch_size, -1))

  # walk over the columns, one window of seq_length notes at a time
  for n in range(0, arr.shape[1], seq_length):
      # features
      x = arr[:, n:n+seq_length]
      # targets start as zeros; the last column is never assigned below and
      # therefore stays 0
      y = np.zeros_like(x)
      try:
          # targets are the inputs shifted left by two positions; the
          # second-to-last target column is filled from column
          # n+seq_length-3 of arr (the third-from-last input column when
          # seq_length >= 3)
          # NOTE(review): the multi-note target layout looks experimental -
          # confirm that leaving the final column at class 0 is intended
          y[:, :-2], y[:, -2] = x[:, 2:], arr[:, n+seq_length-3]
      except IndexError:
          # fallback that takes the second-to-last target from column 3 of arr
          # NOTE(review): for seq_length >= 3 the index above is always inside
          # arr, so this branch appears unreachable - verify before relying on it
          y[:, :-2], y[:, -2] = x[:, 2:], arr[:, 3]
      yield x, y

# save the model
def save_model(best=False):
  """Write a checkpoint of the current training state to disk.

  Reads the module-level `net`, `criterion`, `optimizer` and `num_epochs`.

  best: when True the checkpoint goes to 'best_noteRNN.net', otherwise to
        'noteRNN.net'.

  A failed save is reported and does not interrupt training.
  """
  if best:
    model_name = 'best_noteRNN.net'
    print("Saving Best Perf Model....")
  else:
    model_name = 'noteRNN.net'

  checkpoint = {'state_dict': net.state_dict(),
                'input_size': net.num_classes,
                'output_size': net.num_classes,
                'criterion_state': criterion.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'epochs': num_epochs}
  try:
    torch.save(checkpoint, model_name)
  except Exception as err:
    # a bare `except:` would also swallow KeyboardInterrupt/SystemExit;
    # catch ordinary errors only and report the cause
    print("Unable to save..")
    print("Reason: {}".format(err))

The class NoteRNN is an LSTM initialized with:


*   `dropout` : specify percent for dropout - 0.3 (30%) is the default
*   `num_layers`: number of hidden layers - 3 is the default
*  `num_hidden`: number of hidden units - 256 is the default
*  `lr`: learning rate - 0.003 is the default
*  `num_classes`: equal to the number of `unique_notes` in our dataset




In [0]:
class NoteRNN(nn.Module):
    """LSTM over one-hot encoded notes, followed by dropout and a linear layer.

    unique_notes: iterable of note names; its length is the input and output size
    num_hidden:   number of hidden units per LSTM layer
    num_layers:   number of stacked LSTM layers
    dropout:      dropout probability (between LSTM layers and before the
                  output layer)
    lr:           stored on the instance as `lr`; not used by the class itself
    """

    def __init__(self, unique_notes, num_hidden=256, num_layers=3,
                 dropout=0.3, lr=0.003):
        super().__init__()
        self.drop_prob = dropout
        self.num_layers = num_layers
        self.num_hidden = num_hidden
        self.lr = lr
        self.num_classes = len(unique_notes)

        # note dictionaries; int2note is built as the exact inverse of
        # note2int so the two stay consistent even if a name is repeated
        self.note2int = {note: number for number, note in enumerate(unique_notes)}
        self.int2note = {number: note for note, number in self.note2int.items()}

        # define LSTM; inter-layer dropout only exists with more than one
        # layer, and passing a non-zero value for a single layer makes
        # PyTorch emit a warning
        self.lstm = nn.LSTM(self.num_classes, num_hidden, num_layers,
                            dropout=dropout if num_layers > 1 else 0,
                            batch_first=True)

        # define dropout layer
        self.dropout = nn.Dropout(dropout)

        # define (fully connected) output layer
        self.fc = nn.Linear(num_hidden, self.num_classes)

    def forward(self, x, hidden):
        """Run a forward pass.

        x:      batch-first input for the LSTM
        hidden: (hidden state, cell state) tuple

        Returns the output-layer scores, one row per (sequence, step) pair,
        and the new hidden tuple.
        """
        # get outputs and new hidden state from LSTM
        r_output, hidden = self.lstm(x, hidden)

        # dropout layer
        out = self.dropout(r_output)

        # stack the per-step LSTM outputs into rows of length num_hidden;
        # contiguous() is required before view()
        out = out.contiguous().view(-1, self.num_hidden)

        # fully-connected layer
        out = self.fc(out)

        return out, hidden

    def init_hidden(self, batch_size):
        """Return zeroed (hidden, cell) tensors of shape
        (num_layers, batch_size, num_hidden).

        The tensors are created with the dtype and on the device of the
        model's own parameters, so they always match wherever the model
        currently lives instead of depending on a module-level GPU flag.
        """
        weight = next(self.parameters())
        shape = (self.num_layers, batch_size, self.num_hidden)
        return (weight.new_zeros(shape), weight.new_zeros(shape))
        

The function `train` takes in a network (`net`) and the `data` (our encoded notes), trains the network for the specified number of `epochs`, and saves the best-performing model (lowest validation loss) whenever the validation loss is evaluated. We split the data into training and validation sets, with the split size given by the `val_split` fraction. `batch_size`, `seq_length`, and `lr` should all be set, though defaults are provided.

In [0]:
def _validation_loss(net, val_data, batch_size, seq_length, num_notes):
    """Return the mean loss of `net` over `val_data` (NaN when there are no batches)."""
    val_h = net.init_hidden(batch_size)
    val_losses = []
    # turn dropout off
    net.eval()
    # no gradients are needed for validation, so skip building the graph
    with torch.no_grad():
        for x, y in get_batches(val_data, batch_size, seq_length):
            x = to_categorical(x, num_notes)
            inputs, targets = torch.from_numpy(x), torch.from_numpy(y).long()
            if train_on_gpu:
                inputs, targets = inputs.cuda(), targets.cuda()
            output, val_h = net(inputs, val_h)
            val_loss = criterion(output, targets.view(batch_size * seq_length))
            val_losses.append(val_loss.item())
    # turn dropout back on for training
    net.train()
    return float(np.mean(val_losses)) if val_losses else float("nan")


def train(net, data, epochs=100, batch_size=10, seq_length=60, lr=0.001, clip=5, val_split=0.1, print_every=20):
    '''
        Train `net` on `data`, using the module-level `criterion` and
        `optimizer`, and checkpoint the best model via `save_model(best=True)`.

        net: a NoteRNN network
        data: encoded notes from which to train the network
        epochs: num of epochs to train
        batch_size: batch size
        seq_length: num of notes per sequence
        lr: learning rate (not read here: the step size is whatever the
            module-level `optimizer` was built with)
        clip: gradient clipping
        val_split: fraction of the data reserved for validation
        print_every: num of steps between validation/logging rounds
    '''
    net.train()

    # create training and val set
    split = int(len(data) * (1 - val_split))
    data, val_data = data[:split], data[split:]
    if train_on_gpu:
        net.cuda()

    counter = 0
    num_notes = net.num_classes
    # float("inf") instead of np.Inf, which NumPy 2.0 removed
    val_loss_min = float("inf")
    for epoch in range(epochs):
        # initialize hidden state
        h = net.init_hidden(batch_size)

        for x, y in get_batches(data, batch_size, seq_length):
            counter += 1

            # one-hot encode and make torch tensors; the loss expects int64
            # class indices as targets
            x = to_categorical(x, num_notes)
            inputs, targets = torch.from_numpy(x), torch.from_numpy(y).long()
            if train_on_gpu:
                inputs, targets = inputs.cuda(), targets.cuda()

            # detach the hidden state, otherwise we'd backprop through the
            # entire training history
            h = tuple(each.detach() for each in h)

            # zero out gradient
            net.zero_grad()

            # get output from net
            output, h = net(inputs, h)

            # calculate the loss and backprop
            loss = criterion(output, targets.view(batch_size * seq_length))
            loss.backward()

            # `clip_grad_norm_` helps prevent the exploding gradient problem in RNNs / LSTMs.
            nn.utils.clip_grad_norm_(net.parameters(), clip)

            # update weights and biases
            optimizer.step()

            # loss stats
            if counter % print_every == 0:
                print("Calculating loss...")
                mean_val_loss = _validation_loss(net, val_data, batch_size,
                                                 seq_length, num_notes)

                print("Epoch: {}/{}...".format(epoch+1, epochs),
                      "Step: {}...".format(counter),
                      "Loss: {:.4f}...".format(loss.item()),
                      "Val Loss: {:.4f}".format(mean_val_loss))

                # save the model if the *mean* validation loss has decreased;
                # comparing only the last validation batch is noisy and fails
                # outright when the validation split yields no batch (a NaN
                # mean compares False here, so nothing is saved in that case)
                if mean_val_loss <= val_loss_min:
                    print('Saving best performing model...')
                    save_model(best=True)
                    val_loss_min = mean_val_loss

Specify the data directory for the music files:

In [0]:
# directory (on the mounted Google Drive) where we store our music data;
# the dataset is built from this directory and the generated MIDI file is written back to it
data_dir = '/content/drive/My Drive/Colab Notebooks/data/music_data/music/classical_/'

Create the `Dataset`, print the `.num_classes`, and retrieve the `encoded_notes` to be passed into the network.

In [0]:
# create dataset: parse the MIDI files in data_dir and build the vocabulary
classical = Dataset(data_dir)
classical.create_dataset()
print("Vocab size: {}".format(classical.num_classes))
encoded_notes = classical.encoded_notes # integer-encoded note sequence used for training


Here we define the network:

In [0]:
# define and print the network
# (dropout and lr are left at the NoteRNN defaults)
num_hidden=500
num_layers=1

net = NoteRNN(classical.unique_notes, num_hidden, num_layers)
print(net)

Specify the training parameters, `optimizer` and `criterion`

In [0]:
# parameters
batch_size = 10 # increase with larger dataset
seq_length = 100# seq length really matters - it's the number of notes per training sequence
num_epochs = 5
# optimizer and criterion
# NOTE: train() and save_model() read `optimizer`, `criterion` and `num_epochs`
# as module-level names, so the learning rate in effect is the one given to
# the optimizer here
lr =0.001
optimizer = torch.optim.SGD(net.parameters(),lr = lr, momentum=0.9)
criterion = nn.CrossEntropyLoss()

# Let's train!

In [0]:
# train the model; validation loss is evaluated (and the best model saved) every 10 steps
train(net, encoded_notes, epochs=num_epochs, batch_size=batch_size, seq_length=seq_length, lr=lr, print_every=10)

The function `predict_multi` takes in a trained network, `net`, and returns the hidden state and the next 3 predicted notes.

In [0]:
# predict_multi takes in a trained network and returns
# hidden state and the predicted next 3 notes
def predict_multi(net, notes, h=None, top_k=None):
    """Predict 3 candidate next notes for every note of an input sequence.

    net:   a trained NoteRNN
    notes: sequence of note names (a single name is treated as a sequence
           of one)
    h:     hidden state to continue from; a fresh zero state for batch size 1
           is created when None
    top_k: sample among the `top_k` most probable classes at each step;
           None means all classes

    Returns `(pred, h)`, where `pred` holds one list of 3 note names per
    input note and `h` is the new hidden state.

    Raises ValueError when `notes` is empty or contains a name that is not
    in `net.note2int`; the network's input size is fixed, so its vocabulary
    cannot be extended at prediction time.
    """
    if isinstance(notes, str):
        notes = [notes]
    if len(notes) == 0:
        raise ValueError("predict_multi needs at least one note")
    unknown = [nt for nt in notes if nt not in net.note2int]
    if unknown:
        raise ValueError("Notes not in the network vocabulary: {}".format(unknown))

    # encode, reshape to (1, len(notes)) and one-hot encode
    x = np.array([net.note2int[nt] for nt in notes]).reshape((1, -1))
    x = to_categorical(x, net.num_classes)

    # put the inputs wherever the model's parameters currently live
    device = next(net.parameters()).device
    inputs = torch.from_numpy(x).to(device)

    if h is None:
        h = net.init_hidden(1)
    # detach hidden state from history
    h = tuple(each.detach() for each in h)

    # inference only: no autograd graph needed
    with torch.no_grad():
        out, h = net(inputs, h)
        # class probabilities, one row per input note
        probs = F.softmax(out, dim=1).cpu()

    # never ask topk for more entries than there are classes
    k = net.num_classes if top_k is None else min(top_k, net.num_classes)

    pred = []
    for p in probs:
        _, top_idx = p.topk(k)
        # draw 3 notes uniformly (with replacement) from the top-k classes.
        # The indices stay a 1-D array even for k == 1: a squeezed 0-d value
        # would make np.random.choice sample from range(value) instead.
        chosen = np.random.choice(top_idx.numpy(), 3)
        # convert int encoded notes back to note names
        pred.append([net.int2note[int(idx)] for idx in chosen])

    return pred, h

The function `sample` generates a sample of notes given a trained network, the `size` of the sample to be generated, and a list of notes which acts as the seed, or `prime`.

In [0]:
# generate a sample of notes given a list of notes (prime)
def sample(net, size, prime=['C'], top_k=None):
    """Generate notes from a trained network, seeded with `prime`.

    net:   a trained NoteRNN
    size:  number of generation rounds run after the prime has been consumed
    prime: list of seed sequences (each a list of note names); a bare note
           name is treated as a one-note sequence. The default list is only
           read, never mutated.
    top_k: forwarded to `predict_multi`

    Returns a flat list of note names: the prime followed by every
    predicted note.

    Raises ValueError when `prime` is empty.
    """
    if train_on_gpu:
        net.cuda()
    else:
        net.cpu()
    # eval mode (don't want dropout on)
    net.eval()

    # copy the prime; wrapping bare strings keeps them from being iterated
    # character by character further down
    notes = [[seq] if isinstance(seq, str) else list(seq) for seq in prime]
    if not notes:
        raise ValueError("prime must contain at least one note sequence")

    h = net.init_hidden(1)
    # feed every prime sequence through the network to warm up the hidden
    # state; only the predictions for the last one are kept
    for seq in list(notes):
        pred_notes, h = predict_multi(net, seq, h, top_k=top_k)
    notes.extend(pred_notes)

    # keep predicting from the most recent prediction (the last 3 predicted
    # notes). Indexing with -3 re-fed an older sequence and raised IndexError
    # when fewer than three sequences had been collected.
    for _ in range(size):
        pred_notes, h = predict_multi(net, notes[-1], h, top_k=top_k)
        notes.extend(pred_notes)

    # flatten the list to be able to generate a midi file
    return [nt for seq in notes for nt in seq]

The class Song takes in `generated_notes` from our sample function, the `note_dict` from our trained network, a `data_dir` specifying where we'd like to save the MIDI file, and a `file_name`.

In [0]:
class Song:
    '''
    Turn a list of generated note names into a MIDI file.

    generated_notes: list of notes (str) to be added to the midi file
    note_dict: dict mapping note/chord/rest names to music21 objects, used to
               convert our str notes back to objects
    data_dir: directory to which we'll save the midi file
    file_name: desired filename of the midi file
    mt: the MidiTrack being built (None until create_song runs)
    d: duration of the most recently processed element
    '''
    def __init__(self, generated_notes, note_dict, data_dir, file_name):
        self.generated_notes = generated_notes
        self.note_dict = note_dict
        self.data_dir = data_dir
        self.file_name = file_name
        self.mt = None
        self.d = 0

    @staticmethod
    def _kind(nt):
        '''Classify a generated name as "note", "rest" or "chord".

        Only the text after the last "$" is inspected: an encoded rest
        carries the names of the preceding notes in front of the separator,
        so testing the whole string for "Note" would misclassify almost
        every rest.
        '''
        label = nt.split('$')[-1]
        if "Note" in label:
            return "note"
        if "Rest" in label:
            return "rest"
        return "chord"

    def create_rest(self, nt):
        '''Return the duration (quarterLength, truncated to int) of the rest
        named after the last "$" separator in `nt`.'''
        # split on the separator, retrieve the last rest
        rest = nt.split('$')[-1]
        # retrieve the duration of the rest
        duration = self.note_dict[rest].duration.quarterLength
        return int(duration)

    def end_track(self):
        '''Append a zero delta time and an END_OF_TRACK event to the track.'''
        # create delta time
        dt = midi.DeltaTime(self.mt)
        dt.time = 0  # end of track, dt=0
        self.mt.events.append(dt)
        # create end of track event
        me = midi.MidiEvent(self.mt)
        me.type = "END_OF_TRACK"
        me.channel = 1  # specify channel
        me.data = ''  # must set data to empty string
        self.mt.events.append(me)

    def create_song(self):
        '''Build a MIDI track from `generated_notes` and write it to
        `file_name` inside `data_dir`.'''
        import os  # local import: join works with or without a trailing slash

        # initialize midi track
        self.mt = midi.MidiTrack(1)

        # where to save the file
        file_path = os.path.join(self.data_dir, self.file_name)

        for nt in self.generated_notes:
            kind = self._kind(nt)

            if kind == "note":
                note_ = self.note_dict[nt]  # retrieve note object
                self.d = int(note_.duration.quarterLength)  # retrieve duration
                # convert note to midi events and add them to the track
                eventList = midi.translate.noteToMidiEvents(note_, includeDeltaTime=True)
                self.mt.events.extend(eventList)

            elif kind == "rest":
                self.d = self.create_rest(nt)
                # a rest is written as two delta-time events: 0, then the duration
                # NOTE(review): self.d is a quarterLength, not a tick count, and
                # the two delta times are not followed by a MIDI event -
                # confirm that the resulting file has the intended timing
                dt1 = midi.DeltaTime(self.mt)
                dt1.time = 0  # start time = 0
                self.mt.events.append(dt1)
                dt2 = midi.DeltaTime(self.mt)
                dt2.time = self.d  # duration of delta time event
                self.mt.events.append(dt2)

            else:
                c = self.note_dict[nt]  # retrieve music21 chord object
                self.d = int(c.duration.quarterLength)  # retrieve duration
                # convert chord to midi events and add them to the track
                eventList = midi.translate.chordToMidiEvents(c, includeDeltaTime=True)
                self.mt.events.extend(eventList)

        # create event to end track
        self.end_track()
        # update events so they are all on this track
        self.mt.updateEvents()
        # create midi file to write
        mf = midi.MidiFile()
        mf.ticksPerQuarterNote = 424  # experiment with different timing values
        mf.tracks.append(self.mt)
        # write midi file; always close the handle, even if writing fails
        print("Writing MIDI track")
        mf.open(file_path, 'wb')
        try:
            mf.write()
        finally:
            mf.close()


# Let's generate some notes!

In [0]:
# generate notes: `prime` holds one seed sequence of two note names, `5` is
# the number of generation rounds, and sampling is limited to the 10 most
# probable classes at each step
generated_notes = sample(net, 5, prime=[['F-sharp in octave 4 Eighth Triplet (1/3 QL) Note',
 'F-sharp in octave 3 16th Note',
]], top_k = 10)


Lastly, we'll create a `Song` by passing in the `generated_notes`, the `.music21_objects` created from our dataset, the `data_dir` where we'll write the MIDI file, and a `file_name`. Use the `create_song` method to actually create the song.

In [0]:
# create song and save it: the MIDI file is written into data_dir under the given file name
song = Song(generated_notes, classical.music21_objects, data_dir, "multinote_generated_classical_9.mid")
song.create_song()