# Overfitting Example
We try to overfit a single simple song with a encoder-decoder LSTM.
This time we use PyTorch and embeddings.
The code is inspired by this tutorial: [https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html](https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html)

In [1]:
from music21 import stream, note, metadata

def get():
    piece = stream.Score()
    p1 = stream.Part()
    p1.id = 'part1'

    notes = [note.Note('C4', type='quarter'),
             note.Note('D4', type='quarter'),
             note.Note('E4', type='quarter'),
             note.Note('F4', type='quarter'),
             note.Note('G4', type='half'),
             note.Note('G4', type='half'),
    
             note.Note('A4', type='quarter'),
             note.Note('A4', type='quarter'),
             note.Note('A4', type='quarter'),
             note.Note('A4', type='quarter'),
             note.Note('G4', type='half'),

             note.Note('A4', type='quarter'),
             note.Note('A4', type='quarter'),
             note.Note('A4', type='quarter'),
             note.Note('A4', type='quarter'),
             note.Note('G4', type='half'),

             note.Note('F4', type='quarter'),
             note.Note('F4', type='quarter'),
             note.Note('F4', type='quarter'),
             note.Note('F4', type='quarter'),
             note.Note('E4', type='half'),
             note.Note('E4', type='half'),

             note.Note('D4', type='quarter'),
             note.Note('D4', type='quarter'),
             note.Note('D4', type='quarter'),
             note.Note('D4', type='quarter'),
             note.Note('C4', type='half')
            ]
    p1.append(notes)
    piece.insert(0, metadata.Metadata())
    piece.metadata.title = 'Alle meine Entchen'
    piece.insert(0, p1)
    return piece, notes

In [2]:
piece, notes = get()
piece.show('midi')
#piece.show() # doesn't work inside the notebook for me

![images/overfitting_piece.PNG](images/overfitting_piece.PNG)

# Part 1: Encoding & Data Preparation

We use:
- 128 midi notes
- 128 additional midi notes. This represents the midi notes which are Tied to the previous note.
- 258 additional symbols (Start, Stop)

Therefore we encode our notes as 131-dimensional vector.

* The encoder get's hald of the song as input
* the decoder has to produce the missing half

In [3]:
import music21
from music21 import pitch, interval, stream
import numpy as np

In [4]:
def getTotalTokens():
    return 128*2+ 2  # 128 midi notes + Start + Stop

def getStartIndex():
    return 256

def getStopIndex():
    return 257

In [5]:
def encodeNoteList(notes, delta):
    sequence = []

    for n in notes:
        if (n.isNote):
            sequence.append(n.pitch.midi)
            ticksOn = int(n.duration.quarterLength / delta)
            #print("ticksOn:", ticksOn)
            for i in range(0, ticksOn-1):
                sequence.append(n.pitch.midi + 128)

        if (n.isChord):
            raise NotImplementedError

    return sequence


def split(notes, splitRatio=0.5):
    splitIndex = int(len(notes)*splitRatio)
    x = notes[0:splitIndex]
    y = notes[splitIndex:] + [getStopIndex()]
    return x, y

In [6]:
input = encodeNoteList(notes, delta=1)
input, target = split(input, splitRatio=0.49)
pairs = [(input,target)]
print(pairs)

[([60, 62, 64, 65, 67, 195, 67, 195, 69, 69, 69, 69, 67, 195, 69, 69], [69, 69, 67, 195, 65, 65, 65, 65, 64, 192, 64, 192, 62, 62, 62, 62, 60, 188, 257])]


# Part 2: Model definition & Training
We use a encoder-decoder model from here: [https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html](https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html) 

In [8]:
import random
import torch
from torch import optim, nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

teacher_forcing_ratio = 1.0
MAX_LENGTH = 10

In [9]:
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)

    def forward(self, input, hidden):
        embedded = self.embedding(input).view(1, 1, -1)

        output = embedded
        output, hidden = self.gru(output, hidden)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)

In [10]:
class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size):
        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden, quatsch):
        output = self.embedding(input).view(1, 1, -1)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        output = self.softmax(self.out(output[0]))
        return output, hidden, quatsch

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)

In [11]:
hidden_size = 64
encoder = EncoderRNN(getTotalTokens(), hidden_size).to(device)
decoder = DecoderRNN(hidden_size, getTotalTokens()).to(device)

In [12]:
def train(input, target, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length=MAX_LENGTH):
    encoder_hidden = encoder.initHidden()

    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    input_length = len(input)
    target_length = len(target)

    input = torch.tensor(input)
    target = torch.tensor(target).view(-1, 1)

    encoder_outputs = torch.zeros(target_length+max_length, encoder.hidden_size, device=device)

    loss = 0

    for ei in range(input_length):
        encoder_output, encoder_hidden = encoder(input[ei], encoder_hidden)
        encoder_outputs[ei] = encoder_output[0, 0]

    decoder_input = torch.tensor([[getStartIndex()]], device=device)

    decoder_hidden = encoder_hidden

    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

    if use_teacher_forcing:
        # Teacher forcing: Feed the target as the next input
        for di in range(target_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_hidden, encoder_outputs)
            loss += criterion(decoder_output, target[di])
            decoder_input = target[di]  # Teacher forcing

    else:
        raise NotImplementedError
        # Without teacher forcing: use its own predictions as the next input
        for di in range(target_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            topv, topi = decoder_output.topk(1)
            decoder_input = topi.squeeze().detach()  # detach from history as input

            loss += criterion(decoder_output, target[di])
            if decoder_input.item() == getStopIndex():
                break

    loss.backward()

    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item() / target_length

In [13]:
def trainIters(pairs, encoder, decoder, epochs, learning_rate=0.01):

    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)

    criterion = nn.NLLLoss()

    for iter in range(0, epochs):
        for example in range(0, len(pairs)):
            training_pair = pairs[example]
            input = training_pair[0]
            target = training_pair[1]

            loss = train(input, target,encoder,decoder, encoder_optimizer, decoder_optimizer,criterion)

        print("Epoch", iter+1, " finished. Loss: ", loss)

In [14]:
trainIters(pairs, encoder, decoder, epochs=100)

Epoch 1  finished. Loss:  5.579245316354852
Epoch 2  finished. Loss:  5.283050537109375
Epoch 3  finished. Loss:  4.988585622687089
Epoch 4  finished. Loss:  4.669613486842105
Epoch 5  finished. Loss:  4.309282804790296
Epoch 6  finished. Loss:  3.9054758172286186
Epoch 7  finished. Loss:  3.4812493575246712
Epoch 8  finished. Loss:  3.0921381900185034
Epoch 9  finished. Loss:  2.7919518320184005
Epoch 10  finished. Loss:  2.5700984754060445
Epoch 11  finished. Loss:  2.3858753003572164
Epoch 12  finished. Loss:  2.222306100945724
Epoch 13  finished. Loss:  2.0761546084755347
Epoch 14  finished. Loss:  1.9452940288342928
Epoch 15  finished. Loss:  1.8270845915141858
Epoch 16  finished. Loss:  1.719544862446032
Epoch 17  finished. Loss:  1.6216381474545127
Epoch 18  finished. Loss:  1.5327499791195518
Epoch 19  finished. Loss:  1.4522033490632709
Epoch 20  finished. Loss:  1.3791188691791736
Epoch 21  finished. Loss:  1.3125279075221012
Epoch 22  finished. Loss:  1.2515234696237665
Epoc

# Part 3: Inference
* Use the trained model and predict the second ahlf of the training data
* Represent the generated as a music21 piece in order to display and play it

In [15]:
with torch.no_grad():
    input_tensor = torch.tensor(input)
    input_length = input_tensor.size()[0]
    encoder_hidden = encoder.initHidden()

    max_length = 25

    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)

    for ei in range(input_length):
        encoder_output, encoder_hidden = encoder(input_tensor[ei],
                                                 encoder_hidden)
        encoder_outputs[ei] += encoder_output[0, 0]

    decoder_input = torch.tensor([[getStartIndex()]], device=device)  # SOS

    decoder_hidden = encoder_hidden

    decoded_words = []
    decoder_attentions = torch.zeros(max_length, max_length)

    for di in range(max_length):
        decoder_output, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_hidden, encoder_outputs)
        #decoder_attentions[di] = decoder_attention.data
        topv, topi = decoder_output.data.topk(1)
        if topi.item() == getStopIndex():
            decoded_words.append(getStopIndex())
            break
        else:
            decoded_words.append(topi.item())

        decoder_input = topi.squeeze().detach()

In [17]:
def decodeSequence(seq, input=None, delta=1):
    notes = []

    for i in range(0, len(seq)):

        index = seq[i]

        if index == getStopIndex():
            break

        if i == 0 and index <= 128:
            n = music21.note.Note()
            n.pitch.midi = index
            notes.append(n)
        elif i == 0:
            print(index)
            raise NotImplementedError

        else:
            previousNote = notes[-1].pitch.midi

            if index <= 128:
                n = music21.note.Note()
                n.pitch.midi = index
                notes.append(n)
            elif index < 128 * 2 and index - 128 == previousNote:
                notes[-1].quarterLength += delta
            else:
                raise NotImplementedError


    if input is not None:
        print("reiin", input)
        notes = input + [music21.note.Rest(type='half')] + notes

    piece = music21.stream.Score()
    p1 = music21.stream.Part()
    p1.id = 'part1'

    p1.append(notes)
    piece.insert(0, music21.metadata.Metadata())
    piece.metadata.title = 'Title'
    piece.insert(0, p1)
    return piece, notes

In [19]:
inputNotes = notes[:int(len(notes)*0.49)]

print("i:", input)
print("d:", decoded_words)
print("t:", target)


i: [60, 62, 64, 65, 67, 195, 67, 195, 69, 69, 69, 69, 67, 195, 69, 69]
d: [69, 69, 67, 195, 65, 65, 65, 65, 64, 192, 64, 192, 62, 62, 62, 62, 62, 60, 188, 257]
t: [69, 69, 67, 195, 65, 65, 65, 65, 64, 192, 64, 192, 62, 62, 62, 62, 60, 188, 257]


In [20]:
p, _ = decodeSequence(decoded_words, inputNotes, delta=1)
p.show('midi')
#p.show()

reiin [<music21.note.Note C>, <music21.note.Note D>, <music21.note.Note E>, <music21.note.Note F>, <music21.note.Note G>, <music21.note.Note G>, <music21.note.Note A>, <music21.note.Note A>, <music21.note.Note A>, <music21.note.Note A>, <music21.note.Note G>, <music21.note.Note A>, <music21.note.Note A>]


![images/overfitting_pytorch.PNG](images/overfitting_pytorch.PNG)

The decoder perfectly reproduced the piece!

## Limitations & future work
- [ ] No chords supported only single notes
- [ ] No Ties between different notes supported
- [ ] Use Attention