In [None]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import torch.optim as optim
import pickle
import pretty_midi
import pt_util
import os
from os import listdir
from os.path import join, isfile
import threading
import tqdm
import random

# Report the installed PyTorch version and whether a CUDA device is available
# to this kernel.
print('Version', torch.__version__)
print('CUDA enabled:', torch.cuda.is_available())

In [None]:
# Working directory for the whole notebook. Set the MUSIC_GEN_BASE_PATH
# environment variable to relocate it instead of editing the default below.
BASE_PATH = os.environ.get('MUSIC_GEN_BASE_PATH', '/Users/markusschiffer/music-generation')
os.makedirs(BASE_PATH, exist_ok=True)
os.chdir(BASE_PATH)

# Prefix (relative to BASE_PATH, trailing separator included) that later cells
# concatenate file names onto. os.path.join('data', '') evaluates to 'data\\'
# on Windows, exactly as before, and to 'data/' on POSIX systems.
DATA_PATH = os.path.join('data', '')
# prepare_data() writes its pickles straight into DATA_PATH, so it must exist.
os.makedirs(DATA_PATH, exist_ok=True)

In [None]:
# Number of worker threads prepare_data() splits the MIDI file list across.
NUM_THREADS = 4

def add_note(ind2note, note2ind, note):
    """Return the index of ``note``, registering it in both mappings if new.

    Parameters
    ----------
    ind2note : dict
        Index -> note mapping; gains an entry keyed by its current length
        when ``note`` is new.
    note2ind : dict
        Note -> index mapping; gains an entry valued by its current length
        when ``note`` is new.
    note : hashable
        The note (the callers in this file pass a tuple of pitches).

    Returns
    -------
    The index stored for ``note`` in ``note2ind``.
    """
    try:
        return note2ind[note]
    except KeyError:
        pass
    note2ind[note] = len(note2ind)
    ind2note[len(ind2note)] = note
    return note2ind[note]

def add_tokenized(tokenized_list, tokenized):
    """Append one tokenized piece to ``tokenized_list`` in place (returns None)."""
    tokenized_list += [tokenized]

def thread_task(note2ind, ind2note, tokenized_musics, dict_lock, list_lock, workload_list, thread_num, fs):
    """
    Worker body: tokenize every MIDI file in ``workload_list``.

    For each file the first instrument is rendered to a piano roll with ``fs``
    columns per second. Every column that contains at least one positive entry
    becomes a token: the tuple of active row indices (pitches) is looked up in,
    or added to, the shared vocabulary. Columns without notes stay 0. The
    resulting ``int64`` array, with leading/trailing zeros trimmed, is appended
    to ``tokenized_musics``.

    Parameters
    ----------
    note2ind, ind2note : dict
        Shared vocabulary mappings, updated in place under ``dict_lock``.
    tokenized_musics : list
        Shared output list, appended to under ``list_lock``.
    dict_lock, list_lock : threading.Lock
        Guards for the vocabulary and the output list. Both are taken with
        ``with`` so they are released even if the guarded call raises;
        otherwise one failing file would block every other worker forever.
    workload_list : list of str
        Paths of the MIDI files this worker is responsible for.
    thread_num : int
        Worker id, used only in progress messages.
    fs : int
        Sampling frequency of the piano roll columns.

    Any exception raised while processing a file is printed and the file is
    skipped; the function never raises for a bad file.
    """
    for i, midi_file in enumerate(workload_list):
        if i % 25 == 0:
            print(f'thread {thread_num} processed {i} files')
        try:  # Malformed MIDI files are reported below and skipped
            midi_pretty_format = pretty_midi.PrettyMIDI(midi_file)
            # Only the first instrument track is used (assumed to be the piano part).
            piano_midi = midi_pretty_format.instruments[0]
            # piano_roll: one row per pitch, one column per time frame
            piano_roll = piano_midi.get_piano_roll(fs=fs)
            # index[0] holds pitches, index[1] the matching time frames
            index = np.where(piano_roll > 0)
            # np.unique returns the frames sorted in ascending order
            times = np.unique(index[1])
            cur_music = np.zeros(piano_roll.shape[1], dtype=np.int64)
            for time in times:
                notes = tuple(index[0][np.where(index[1] == time)])
                if notes in note2ind:  # chord already known, no lock needed to read
                    cur_music[time] = note2ind[notes]
                else:
                    # add_note re-checks membership, so losing the race to
                    # another thread still yields a single, consistent index.
                    with dict_lock:
                        cur_music[time] = add_note(ind2note, note2ind, notes)
            with list_lock:
                add_tokenized(tokenized_musics, np.trim_zeros(cur_music))
        except Exception as e:
            print(e)
            print("broken file : {}".format(midi_file))

def prepare_data(data_path, fs, pickle_name, all_maestro=False, remove_empty=False):
    """Tokenize a folder of MIDI files and pickle an 80/20 train/test split.

    The file list is divided into ``NUM_THREADS`` contiguous slices, each
    handled by one ``thread_task`` worker. The per-file token arrays are
    concatenated (in the order the workers finished them, which is not
    deterministic), the first 80% of the tokens become the training stream and
    the rest the test stream. Both are written, together with the vocabulary
    mappings, to ``DATA_PATH + "music_{train,test}_{pickle_name}_{fs}.pkl"``.

    Parameters
    ----------
    data_path : str
        Directory containing the MIDI files, or, when ``all_maestro`` is
        true, a directory with one sub-directory per year listed below.
    fs : int
        Piano-roll sampling frequency forwarded to the workers; also part of
        the output file names.
    pickle_name : str
        Name fragment used in the output file names.
    all_maestro : bool
        Read the per-year sub-directories instead of ``data_path`` itself.
    remove_empty : bool
        Currently unused; kept so existing calls keep working.

    Raises
    ------
    ValueError
        If no file could be tokenized (nothing to concatenate).
    """
    if all_maestro:
        dates = ["2004", "2006", "2008", "2009", "2011", "2013", "2014", "2015", "2017", "2018"]
        midi_files = [join(data_path, date, f) for date in dates for f in listdir(join(data_path, date))]
    else:
        midi_files = [join(data_path, f) for f in listdir(data_path)]

    # Split into NUM_THREADS contiguous slices whose sizes differ by at most one.
    k, m = divmod(len(midi_files), NUM_THREADS)
    workloads = [midi_files[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(NUM_THREADS)]

    # Shared state filled in by the workers. Index 0 is reserved for the empty
    # note (silence).
    empty_note = ()
    note2ind = {empty_note: 0}
    ind2note = {0: empty_note}
    # One np array per successfully parsed file; its length is the song length.
    tokenized_musics = []
    dict_lock = threading.Lock()
    list_lock = threading.Lock()

    threads = [
        threading.Thread(
            target=thread_task,
            args=(note2ind, ind2note, tokenized_musics, dict_lock, list_lock, workloads[j], j, fs),
        )
        for j in range(NUM_THREADS)
    ]
    for thread in threads:
        thread.start()
    # Wait until every worker has finished its slice.
    for thread in threads:
        thread.join()

    if not tokenized_musics:
        raise ValueError(f"no MIDI file under {data_path!r} could be tokenized")

    tokenized = np.concatenate(tokenized_musics, axis=0)
    split_at = int(0.8 * len(tokenized))
    splits = (("train", tokenized[:split_at]), ("test", tokenized[split_at:]))

    for split_name, tokens in splits:
        out_path = DATA_PATH + f"music_{split_name}_{pickle_name}_{fs}.pkl"
        # `with` guarantees the file is flushed and closed before we move on.
        with open(out_path, "wb") as out_file:
            pickle.dump({'tokens': tokens, 'ind2note': ind2note, 'note2ind': note2ind}, out_file)

# Tokenize the Schumann folder at 2 frames per second. prepare_data appends
# "_{fs}" to the pickle name, so passing "schumann" produces
# music_{train,test}_schumann_2.pkl -- the names the training cell builds from
# DATASET_NAME = "schumann_2". (Passing "schumann_2" here wrote
# ..._schumann_2_2.pkl, which the training cell never reads.)
prepare_data("musics/Schumann", 2, "schumann", all_maestro=False)

In [None]:
class Vocabulary(object):
    """Note <-> index lookup tables read from a dataset pickle.

    The pickle must hold a dict with the keys ``'ind2note'`` and
    ``'note2ind'``.
    """

    def __init__(self, data_file):
        # NOTE: pickle.load can execute arbitrary code -- only open trusted files.
        with open(data_file, 'rb') as handle:
            payload = pickle.load(handle)
        self.ind2note = payload['ind2note']
        self.note2ind = payload['note2ind']

    def array_to_notes(self, arr):
        """Map an iterable of scalar tensors (token ids) to their notes."""
        lookup = self.ind2note
        notes = []
        for ind in arr:
            # .item() yields a Python number; int() also accepts float ids.
            notes.append(lookup[int(ind.item())])
        return notes

    def notes_to_array(self, notes):
        """Map an iterable of notes to a ``torch.LongTensor`` of token ids."""
        indices = list(map(self.note2ind.__getitem__, notes))
        return torch.LongTensor(indices)

    def __len__(self):
        """Number of distinct notes in the vocabulary."""
        return len(self.note2ind)

In [None]:
class PianoDataset(torch.utils.data.Dataset):
    """Next-token prediction samples cut from a pickled token stream.

    The token stream is trimmed to a multiple of ``batch_size`` and split into
    ``batch_size`` contiguous chunks of equal length. Samples are stored row
    by row: first the opening window of every chunk, then the second window of
    every chunk, and so on. Read with a ``DataLoader`` using the same
    ``batch_size`` and ``shuffle=False``, entry ``k`` of each batch therefore
    continues entry ``k`` of the previous batch.

    Each sample is ``(data, label)`` where ``label`` is ``data`` shifted one
    token to the right. Windows hold ``sequence_length`` tokens except for the
    last row, which may be shorter (the same length for every chunk).
    """

    def __init__(self, data_file, sequence_length, batch_size):
        """
        Parameters
        ----------
        data_file : str
            Pickle holding ``'tokens'`` plus the vocabulary mappings read by
            ``Vocabulary``.
        sequence_length : int
            Maximum number of tokens per sample; must be at least 1.
        batch_size : int
            Number of parallel chunks the stream is divided into.

        Raises
        ------
        ValueError
            If ``sequence_length`` is smaller than 1 or the stream holds fewer
            than ``batch_size`` tokens.
        """
        super(PianoDataset, self).__init__()

        if sequence_length < 1:
            raise ValueError('sequence_length must be at least 1')

        self.sequence_length = sequence_length
        self.batch_size = batch_size
        self.vocab = Vocabulary(data_file)
        self.data = []

        # NOTE: pickle.load can execute arbitrary code -- only open trusted files.
        with open(data_file, 'rb') as data_pkl:
            dataset = pickle.load(data_pkl)

        # Convert once; the samples below are slices of this tensor.
        tokens = torch.as_tensor(dataset['tokens'], dtype=torch.long)
        chunk_size = len(tokens) // batch_size
        if chunk_size == 0:
            raise ValueError(
                'need at least batch_size={} tokens, got {}'.format(batch_size, len(tokens)))
        # Drop the remainder so every chunk has exactly chunk_size tokens.
        tokens = tokens[:chunk_size * batch_size]

        for offset in range(0, chunk_size, sequence_length):
            for chunk_start in range(0, len(tokens), chunk_size):
                chunk_last = chunk_start + chunk_size - 1  # last valid index of this chunk
                start = chunk_start + offset
                # The label needs one token beyond the data window, so the
                # window may end at chunk_last at the latest.
                end = min(start + sequence_length, chunk_last)
                if end <= start:
                    # Only the chunk's final token is left: there is no label
                    # for it, and an empty sample would produce an empty batch.
                    continue
                self.data.append((tokens[start:end], tokens[start + 1:end + 1]))

    def __len__(self):
        """Number of stored (data, label) samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``idx``-th ``(data, label)`` pair of long tensors."""
        return self.data[idx]

    def vocab_size(self):
        """Number of distinct notes in the vocabulary loaded from ``data_file``."""
        return len(self.vocab)

In [None]:
class PianoLSTMNet(nn.Module):
    """Embedding -> 2-layer LSTM -> linear decoder over the note vocabulary.

    The decoder's weight matrix is the embedding matrix itself (tied weights),
    so both have shape ``(vocab_size, feature_size)``.
    """

    def __init__(self, vocab_size, feature_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.feature_size = feature_size

        self.encoder = nn.Embedding(vocab_size, feature_size)
        self.lstm = nn.LSTM(feature_size, feature_size, num_layers=2, batch_first=True)
        self.decoder = nn.Linear(feature_size, vocab_size)

        # Tie the output projection to the input embedding and start the
        # decoder bias at zero.
        self.decoder.weight = self.encoder.weight
        self.decoder.bias.data.zero_()

        # Highest accuracy passed to save_best_model so far.
        self.best_accuracy = -1

    def forward(self, x, hidden_state=None):
        """Return ``(logits, hidden_state)`` for a batch-first tensor of token ids."""
        embedded = self.encoder(x)
        features, hidden_state = self.lstm(embedded, hidden_state)
        return self.decoder(features), hidden_state

    def inference(self, x, hidden_state=None, temperature=1):
        """Return a ``(1, vocab_size)`` probability row for the next token.

        The input is reshaped to one token per row and the logits are
        flattened into a single row, so this is meant for one token at a time.
        Logits are divided by ``temperature`` (floored at 1e-20) before the
        softmax.
        """
        logits, hidden_state = self.forward(x.view(-1, 1), hidden_state)
        scaled = logits.view(1, -1) / max(temperature, 1e-20)
        return F.softmax(scaled, dim=1), hidden_state

    def loss(self, prediction, label, reduction='mean'):
        """Cross-entropy between flattened logits and flattened labels."""
        flat_prediction = prediction.view(-1, self.vocab_size)
        flat_label = label.view(-1)
        return F.cross_entropy(flat_prediction, flat_label, reduction=reduction)

    def save_model(self, file_path, num_to_keep=1):
        """Save this model to ``file_path`` through ``pt_util.save``."""
        pt_util.save(self, file_path, num_to_keep)

    def save_best_model(self, accuracy, file_path, num_to_keep=1):
        """Save only when ``accuracy`` beats the best value seen so far."""
        if accuracy > self.best_accuracy:
            self.save_model(file_path, num_to_keep)
            self.best_accuracy = accuracy

    def load_model(self, file_path):
        """Restore this model from ``file_path`` through ``pt_util.restore``."""
        pt_util.restore(self, file_path)

    def load_last_model(self, dir_path):
        """Restore through ``pt_util.restore_latest`` and return its result."""
        return pt_util.restore_latest(self, dir_path)

In [None]:
def generate_single_note_seed(vocab, seq_len=50):
    """Build a seed of ``seq_len - 1`` empty notes followed by one random note.

    Parameters
    ----------
    vocab : object
        Must expose ``ind2note`` (index 0 is the empty note) and ``note2ind``.
    seq_len : int
        Total length of the returned seed.

    Returns
    -------
    list
        ``seq_len - 1`` copies of the empty note, then one note drawn
        uniformly from the non-empty notes of the vocabulary.

    Raises
    ------
    ValueError
        If the vocabulary contains nothing but the empty note.
    """
    empty_note = vocab.ind2note[0]
    # Draw from the non-empty notes directly instead of re-drawing until the
    # pick happens to differ from the empty note.
    candidates = [note for note in vocab.note2ind if note != empty_note]
    if not candidates:
        raise ValueError('vocabulary contains no non-empty note to seed with')
    seed = [empty_note] * (seq_len - 1)
    seed.append(random.choice(candidates))
    return seed


def generate_song_start_seed(data_path, fs, seq_len=50):
    """Return the first ``seq_len`` chords of a MIDI file as a seed.

    The first instrument of the file is rendered to a piano roll with ``fs``
    columns per second. Only columns that contain at least one note are
    considered (silent columns are skipped), each contributing the tuple of
    its active pitches.

    Parameters
    ----------
    data_path : str
        Path of the MIDI file.
    fs : int
        Sampling frequency of the piano roll columns.
    seq_len : int
        Number of chords to return.

    Returns
    -------
    list of tuple
        Exactly ``seq_len`` pitch tuples.

    Raises
    ------
    ValueError
        If ``data_path`` is not a file, or the file has fewer than ``seq_len``
        non-silent columns.
    """
    if not isfile(data_path):
        raise ValueError('Enter a valid midi file')

    midi_pretty_format = pretty_midi.PrettyMIDI(data_path)
    # Only the first instrument track is used (assumed to be the piano part).
    piano_roll = midi_pretty_format.instruments[0].get_piano_roll(fs=fs)
    pitches, frames = np.where(piano_roll > 0)

    cur_music = []
    for time in np.unique(frames):  # ascending frame order
        cur_music.append(tuple(pitches[np.where(frames == time)]))
        if len(cur_music) == seq_len:
            return cur_music

    raise ValueError(
        '{} has only {} non-silent frames at fs={}, but seq_len={} were requested'.format(
            data_path, len(cur_music), fs, seq_len))

def piano_roll_to_pretty_midi(piano_roll, fs, program=0):
    '''Build a single-instrument PrettyMIDI object from a piano roll.

    Parameters
    ----------
    piano_roll : np.ndarray, shape=(128,frames), dtype=int
        Piano roll of one instrument
    fs : int
        Sampling frequency of the columns, i.e. each column is spaced apart
        by ``1./fs`` seconds.
    program : int
        The program number of the instrument.

    Returns
    -------
    midi_object : pretty_midi.PrettyMIDI
        A pretty_midi.PrettyMIDI class instance describing
        the piano roll.
    '''
    n_pitches, _ = piano_roll.shape
    midi = pretty_midi.PrettyMIDI()
    track = pretty_midi.Instrument(program=program)

    # One zero column on each side so notes sounding in the first or last
    # frame still produce an on/off transition.
    padded = np.pad(piano_roll, [(0, 0), (1, 1)], 'constant')

    # Every (frame, pitch) position where the velocity differs from the
    # previous frame is a candidate note-on / note-off event.
    change_frames, change_pitches = np.nonzero(np.diff(padded).T)

    # Velocity and start time of the note currently sounding on each pitch
    # (velocity 0 means the pitch is silent).
    active_velocity = np.zeros(n_pitches, dtype=int)
    onset = np.zeros(n_pitches)

    for frame, pitch in zip(change_frames, change_pitches):
        # frame + 1 compensates for the padding column added above
        velocity = padded[pitch, frame + 1]
        seconds = frame / fs
        if velocity > 0:
            # A velocity change while the pitch is already sounding is ignored.
            if active_velocity[pitch] == 0:
                onset[pitch] = seconds
                active_velocity[pitch] = velocity
            continue
        track.notes.append(pretty_midi.Note(
            velocity=active_velocity[pitch],
            pitch=pitch,
            start=onset[pitch],
            end=seconds))
        active_velocity[pitch] = 0

    midi.instruments.append(track)
    return midi

In [15]:
# Softmax temperature passed to model.inference() by every sampling strategy below.
TEMPERATURE = 0.9
# Default beam width of generate_language() for sampling_strategy='beam'.
BEAM_WIDTH = 10

def max_sampling_strategy(sequence_length, model, output, hidden, vocab):
    """Greedy decoding: always continue with the most probable token.

    Starting from the distribution ``output`` and state ``hidden``, picks the
    argmax token, feeds it to ``model.inference`` (with ``TEMPERATURE``) and
    repeats ``sequence_length`` times. ``vocab`` is not used.

    Returns the list of chosen token ids as 0-d tensors.
    """
    chosen = []
    distribution, state = output, hidden
    for _step in range(sequence_length):
        best = torch.argmax(distribution)
        chosen.append(best)
        distribution, state = model.inference(best, state, TEMPERATURE)
    return chosen
    
def sample_sampling_strategy(sequence_length, model, output, hidden, vocab):
    """Stochastic decoding: draw each token from the current distribution.

    Starting from the distribution ``output`` and state ``hidden``, draws one
    token with ``torch.multinomial``, feeds it to ``model.inference`` (with
    ``TEMPERATURE``) and repeats ``sequence_length`` times. ``vocab`` is not
    used.

    Returns the list of drawn token ids, each exactly as returned by
    ``torch.multinomial(..., 1)``.
    """
    drawn = []
    distribution, state = output, hidden
    for _step in range(sequence_length):
        token = torch.multinomial(distribution, 1)
        drawn.append(token)
        distribution, state = model.inference(token, state, TEMPERATURE)
    return drawn

def beam_thread_task(beam, beam_width, model, new_beams, beam_lock):
    """Expand one beam into ``beam_width`` sampled continuations.

    Parameters
    ----------
    beam : tuple
        ``(generated, probs, hidden, score)``: the tokens generated so far,
        the next-token probabilities, the model state and the cumulative
        log-probability of ``generated``.
    beam_width : int
        Number of continuations to draw (with replacement) from ``probs``.
    model : object
        Must provide ``inference(token, hidden, temperature)``.
    new_beams : list
        Shared output list; one tuple in the same layout as ``beam`` is
        appended per continuation.
    beam_lock : threading.Lock
        Guards ``new_beams``.
    """
    generated, probs, hidden, score = beam
    sample_idxs = torch.multinomial(probs, beam_width, replacement=True)
    # assumes probs is a single distribution (one row), as produced by
    # model.inference in this file, so a sampled id indexes the flattened row
    flat_probs = probs.reshape(-1)
    for new_sample in torch.flatten(sample_idxs):
        new_output, new_hidden = model.inference(new_sample, hidden, TEMPERATURE)
        new_outputs = generated.copy()  # tokens generated so far
        new_outputs.append(new_sample)
        # Score with the log-probability of the sampled token, not the log of
        # its index.
        new_score = score + torch.log(flat_probs[new_sample]).item()
        with beam_lock:
            new_beams.append((new_outputs, new_output, new_hidden, new_score))

def beam_sampling_strategy(sequence_length, beam_width, model, output, hidden, vocab):
    """Sampled beam search over ``sequence_length`` steps.

    At every step each surviving beam is expanded by ``beam_thread_task`` in
    its own thread; the candidates are then ranked by their score (element 3
    of the beam tuple, highest first) and the best ``beam_width`` are kept.
    ``vocab`` is not used.

    Returns the token list of the best-scoring beam after the last step.
    """
    beam_lock = threading.Lock()
    beams = [([], output, hidden, 0)]
    for _step in range(sequence_length):
        candidates = []
        workers = []
        for beam in beams:
            workers.append(threading.Thread(
                target=beam_thread_task,
                args=(beam, beam_width, model, candidates, beam_lock)))
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()
        candidates.sort(key=lambda candidate: candidate[3], reverse=True)
        beams = candidates[:beam_width]

    return beams[0][0]

def generate_language(model, device, seed_notes, vocab, midi_file_name, fs,
                      start_index=49, max_generated=300, sampling_strategy='max', beam_width=BEAM_WIDTH):
    """Generate a continuation of ``seed_notes`` and write it to a MIDI file.

    The seed is fed through the model one token at a time to build up the
    hidden state, ``max_generated`` further tokens are produced with the
    chosen strategy, and the notes from position ``start_index`` of the
    combined (seed + generated) sequence onwards are written to
    ``midi_file_name`` with every note at velocity 100.

    Parameters
    ----------
    model : object
        Network providing ``eval()`` and ``inference()``.
    device : torch.device
        Device the seed tokens are moved to before inference.
    seed_notes : list
        Notes (vocabulary keys) used as the prompt; must not be empty.
    vocab : object
        Provides ``notes_to_array``, ``array_to_notes`` and ``ind2note``.
    midi_file_name : str
        Output path.
    fs : int
        Columns per second of the piano roll that is converted to MIDI.
    start_index : int
        First position of the combined sequence that is written out. The
        default of 49 keeps only the last token of a 50-token seed.
    max_generated : int
        Number of tokens to generate.
    sampling_strategy : str
        One of ``'max'``, ``'sample'`` or ``'beam'``.
    beam_width : int
        Beam width, used only by the ``'beam'`` strategy.

    Raises
    ------
    ValueError
        If ``sampling_strategy`` is unknown or ``seed_notes`` is empty.
    """
    if sampling_strategy not in ('max', 'sample', 'beam'):
        raise ValueError(
            "unknown sampling_strategy {!r}; expected 'max', 'sample' or 'beam'".format(sampling_strategy))

    model.eval()

    with torch.no_grad():
        seed_notes_arr = vocab.notes_to_array(seed_notes)
        if len(seed_notes_arr) == 0:
            raise ValueError('seed_notes must contain at least one note')

        # Computes the initial hidden state from the prompt (seed notes).
        hidden = None
        for ind in seed_notes_arr:
            data = ind.to(device)
            output, hidden = model.inference(data, hidden)

        if sampling_strategy == 'max':
            outputs = max_sampling_strategy(max_generated, model, output, hidden, vocab)
        elif sampling_strategy == 'sample':
            outputs = sample_sampling_strategy(max_generated, model, output, hidden, vocab)
        else:
            outputs = beam_sampling_strategy(max_generated, beam_width, model, output, hidden, vocab)

        # The strategies return one-element tensors of differing shapes that
        # may live on `device`; turn them into one CPU long tensor so it can
        # be concatenated with the seed ids.
        generated = torch.tensor([int(token) for token in outputs], dtype=torch.long)
        note_arr = vocab.array_to_notes(torch.cat((seed_notes_arr, generated)))

        # One piano-roll column per emitted note. With the default arguments
        # and a 50-token seed this is max_generated + 1 columns, as before.
        emitted = note_arr[start_index:]
        array_piano_roll = np.zeros((128, len(emitted)), dtype=np.int16)
        for index, note in enumerate(emitted):
            if note != vocab.ind2note[0]:
                for j in note:
                    array_piano_roll[int(j), index] = 1
        generate_to_midi = piano_roll_to_pretty_midi(array_piano_roll, fs=fs)
        # The piano roll only marks on/off, so give every note a fixed velocity.
        for note in generate_to_midi.instruments[0].notes:
            note.velocity = 100
        generate_to_midi.write(midi_file_name)

In [None]:
import tqdm
def repackage_hidden(h):
    """Detach hidden states from their autograd history.

    A tensor is returned detached; any other iterable is processed element by
    element, recursively, and returned as a tuple.
    """
    if not isinstance(h, torch.Tensor):
        return tuple(map(repackage_hidden, h))
    return h.detach()

def train(model, device, optimizer, train_loader, lr, epoch, log_interval):
    """Run one training epoch and return the mean batch loss.

    The hidden state returned for one batch is detached and passed into the
    next one, the same way ``test`` carries it across batches, so gradients
    stop at batch boundaries while the recurrent context is preserved. This
    assumes consecutive batches have the same batch size and continue each
    other, which is how ``PianoDataset`` lays out its samples when the loader
    is not shuffled.

    Parameters
    ----------
    model : nn.Module
        Called as ``model(data, hidden)``; must also provide ``loss``.
    device : torch.device
        Device the batches are moved to.
    optimizer : torch.optim.Optimizer
        Optimizer stepping the model's parameters.
    train_loader : iterable
        Yields ``(data, label)`` batches; needs ``len()`` and ``.dataset``.
    lr : float
        NOTE(review): accepted but not applied -- the optimizer keeps the
        learning rate it was constructed with. Confirm whether the schedule
        computed by the caller is meant to take effect.
    epoch : int
        Epoch number, used only in the progress message.
    log_interval : int
        A progress line is printed every ``log_interval`` batches.

    Returns
    -------
    float
        Mean of the per-batch losses.
    """
    model.train()
    losses = []
    hidden = None
    for batch_idx, (data, label) in enumerate(tqdm.tqdm(train_loader)):
        data, label = data.to(device), label.to(device)
        # Separates the hidden state across batches.
        # Otherwise the backward would try to go all the way to the beginning every time.
        if hidden is not None:
            hidden = repackage_hidden(hidden)
        optimizer.zero_grad()
        output, hidden = model(data, hidden)
        loss = model.loss(output, label)
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
    return np.mean(losses)


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0

    with torch.no_grad():
        hidden = None
        for batch_idx, (data, label) in enumerate(test_loader):
            data, label = data.to(device), label.to(device)
            output, hidden = model(data, hidden)
            test_loss += model.loss(output, label, reduction='mean').item()
            pred = output.max(-1)[1]
            correct_mask = pred.eq(label.view_as(pred))
            num_correct = correct_mask.sum().item()
            correct += num_correct          

    test_loss /= len(test_loader)
    test_accuracy = 100. * correct / (len(test_loader.dataset) * test_loader.dataset.sequence_length)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset) * test_loader.dataset.sequence_length,
        100. * correct / (len(test_loader.dataset) * test_loader.dataset.sequence_length)))
    return test_loss, test_accuracy

In [None]:
# Length of the seed built by eval_final_model() for generation.
SEED_SIZE = 50
# Maximum number of tokens per PianoDataset sample.
SEQUENCE_LENGTH = 100
# Batch size of the training loader and chunk count of the training dataset.
BATCH_SIZE = 256
# Embedding and LSTM hidden size of PianoLSTMNet.
FEATURE_SIZE = 512
# Batch size of the test loader and chunk count of the test dataset.
TEST_BATCH_SIZE = 256
# Last epoch number of the training loop (inclusive).
EPOCHS = 200
# Learning rate given to Adam; also the base of the per-epoch value passed to train().
LEARNING_RATE = 0.002
# Weight decay given to Adam.
WEIGHT_DECAY = 0.0005
# Use the GPU when one is available.
USE_CUDA = True
# train() prints a progress line every PRINT_INTERVAL batches.
PRINT_INTERVAL = 10
# File read and written through pt_util.read_log / pt_util.write_log.
LOG_PATH = DATA_PATH + 'logs/log.pkl'
# Selects the pickles to load: DATA_PATH + "music_{train,test}_{DATASET_NAME}.pkl".
# prepare_data names its output "music_{split}_{pickle_name}_{fs}.pkl", so this
# must equal "{pickle_name}_{fs}" of the run that produced the data.
DATASET_NAME = "schumann_2"

def main():
    """Train the piano LSTM, resuming from the latest checkpoint.

    Loads the train/test pickles selected by ``DATASET_NAME``, builds the
    loaders and the model, restores the latest checkpoint, evaluates once and
    then runs epochs ``start_epoch`` .. ``EPOCHS``. After each epoch the log
    is written and the model is saved if its test accuracy improved. Whether
    training finishes, is interrupted with Ctrl-C, or fails with an exception
    (whose traceback is printed), the final model is saved and the recorded
    curves are plotted.

    Returns
    -------
    (model, vocab, device)
    """
    data_train = PianoDataset(DATA_PATH + f"music_train_{DATASET_NAME}.pkl", SEQUENCE_LENGTH, BATCH_SIZE)
    data_test = PianoDataset(DATA_PATH + f"music_test_{DATASET_NAME}.pkl", SEQUENCE_LENGTH, TEST_BATCH_SIZE)
    vocab = data_train.vocab

    use_cuda = USE_CUDA and torch.cuda.is_available()

    device = torch.device("cuda" if use_cuda else "cpu")
    print('Using device', device)
    num_workers = 0
    print('num workers:', num_workers)

    kwargs = {'num_workers': num_workers,
              'pin_memory': True} if use_cuda else {}

    # shuffle=False is required: PianoDataset orders its samples so that
    # consecutive batches continue each other.
    train_loader = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE,
                                               shuffle=False, **kwargs)
    test_loader = torch.utils.data.DataLoader(data_test, batch_size=TEST_BATCH_SIZE,
                                              shuffle=False, **kwargs)

    model = PianoLSTMNet(data_train.vocab_size(), FEATURE_SIZE).to(device)

    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    start_epoch = model.load_last_model(DATA_PATH + 'checkpoints')

    train_losses, test_losses, test_accuracies = pt_util.read_log(LOG_PATH, ([], [], []))
    test_loss, test_accuracy = test(model, device, test_loader)

    test_losses.append((start_epoch, test_loss))
    test_accuracies.append((start_epoch, test_accuracy))

    # Pre-bind `epoch` so the final save below has a checkpoint name even when
    # the loop body never runs (start_epoch > EPOCHS).
    epoch = start_epoch
    try:
        for epoch in range(start_epoch, EPOCHS + 1):
            lr = LEARNING_RATE * np.power(0.25, (int(epoch / 6)))
            train_loss = train(model, device, optimizer, train_loader, lr, epoch, PRINT_INTERVAL)
            test_loss, test_accuracy = test(model, device, test_loader)
            train_losses.append((epoch, train_loss))
            test_losses.append((epoch, test_loss))
            test_accuracies.append((epoch, test_accuracy))
            pt_util.write_log(LOG_PATH, (train_losses, test_losses, test_accuracies))
            model.save_best_model(test_accuracy, DATA_PATH + 'checkpoints/%03d.pt' % epoch)

    except KeyboardInterrupt:
        print('Interrupted')
    except Exception:
        # Best effort: report the failure but still save and return the model.
        import traceback
        traceback.print_exc()
    finally:
        print('Saving final model')
        model.save_model(DATA_PATH + 'checkpoints/%03d.pt' % epoch, 0)

        # (series, plot title, transform applied to the values)
        plots = (
            (train_losses, 'Train loss', None),
            (train_losses, 'Train Perplexity', np.exp),
            (test_losses, 'Test loss', None),
            (test_losses, 'Test Perplexity', np.exp),
            (test_accuracies, 'Test accuracy', None),
        )
        for series, plot_title, transform in plots:
            if not series:
                # Nothing recorded (e.g. interrupted during the first epoch).
                continue
            ep, val = zip(*series)
            if transform is not None:
                val = transform(val)
            pt_util.plot(ep, val, plot_title, 'Epoch', 'Error')

    return model, vocab, device

# Run training and keep the model, vocabulary and device for the generation cells below.
final_model, vocab, device = main()

In [None]:
# Suffix embedded in the names of the MIDI files written by eval_final_model().
TITLE = "schumann_lstm_dataset_fsin_2_fsout_8_temp_0p9"

def eval_final_model(model, vocab, device, title, fs):
    """Write sample MIDI files for every sampling strategy.

    One random single-note seed of ``SEED_SIZE`` tokens is generated and used
    for all runs: one greedy ('max') file, five 'sample' files and five
    'beam' files, named ``{max_sample|sampling_sample_i|beam_sample_i}_{title}.mid``.

    Parameters
    ----------
    model : object
        Trained network passed on to ``generate_language``.
    vocab : object
        Vocabulary matching ``model``.
    device : torch.device
        Device ``model`` lives on.
    title : str
        Suffix of the output file names.
    fs : int
        Piano-roll sampling frequency of the written MIDI files.
    """
    seed_notes = generate_single_note_seed(vocab, SEED_SIZE)
    # Write the output starting at the seed's last token (its only real note).
    # Deriving the index from the seed keeps this correct when SEED_SIZE is not
    # 50, which generate_language's default start_index of 49 assumes.
    start_index = len(seed_notes) - 1

    generate_language(model, device, seed_notes, vocab, f"max_sample_{title}.mid",
                      fs=fs, start_index=start_index)
    print('generated with max')

    for strategy, prefix in (('sample', 'sampling_sample'), ('beam', 'beam_sample')):
        for ii in range(5):
            generate_language(model, device, seed_notes, vocab, f"{prefix}_{ii}_{title}.mid",
                              sampling_strategy=strategy, fs=fs, start_index=start_index)
            print(f'generated with {strategy}')

# Generate samples from the model returned by main(), written at fs=8.
eval_final_model(final_model, vocab, device, TITLE, 8)

In [13]:
# Build a fresh network, restore its weights from the checkpoint directory via
# load_last_model(), and generate samples from it at fs=8.
last_model = PianoLSTMNet(len(vocab), FEATURE_SIZE).to(device)
last_model.load_last_model(DATA_PATH + 'checkpoints')
eval_final_model(last_model, vocab, device, "schumann_latest_lstm_dataset_fsin_2_fsout_8_temp_0p9", 8)

Restoring:
encoder.weight -> 	torch.Size([42436, 512]) = 86MB
lstm.weight_ih_l0 -> 	torch.Size([2048, 512]) = 4MB
lstm.weight_hh_l0 -> 	torch.Size([2048, 512]) = 4MB
lstm.bias_ih_l0 -> 	torch.Size([2048]) = 0MB
lstm.bias_hh_l0 -> 	torch.Size([2048]) = 0MB
lstm.weight_ih_l1 -> 	torch.Size([2048, 512]) = 4MB
lstm.weight_hh_l1 -> 	torch.Size([2048, 512]) = 4MB
lstm.bias_ih_l1 -> 	torch.Size([2048]) = 0MB
lstm.bias_hh_l1 -> 	torch.Size([2048]) = 0MB
decoder.weight -> 	torch.Size([42436, 512]) = 86MB
decoder.bias -> 	torch.Size([42436]) = 0MB

Restored all variables
No new variables
Restored data\checkpoints\088.pt
generated with max
generated with sample
generated with sample
generated with sample
generated with sample
generated with sample
generated with beam
generated with beam
generated with beam
generated with beam
generated with beam


In [16]:
# Generate another set of samples from the restored checkpoint model, this time at fs=2.
eval_final_model(last_model, vocab, device, "maestro_latest_dataset_fsin_2_fsout_2_temp_0p9", 2)

generated with max
generated with sample
generated with sample
generated with sample
generated with sample
generated with sample
generated with beam
generated with beam
generated with beam
generated with beam
generated with beam


In [17]:
# Generate another set of samples from the model returned by main(), this time at fs=2.
eval_final_model(final_model, vocab, device, "maestro_dataset_fsin_2_fsout_2_temp_0p9", 2)

generated with max
generated with sample
generated with sample
generated with sample
generated with sample
generated with sample
generated with beam
generated with beam
generated with beam
generated with beam
generated with beam
