<a href="https://colab.research.google.com/github/edufantini/music-gen/blob/main/src/MusicGeneratorLSTM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Music Generator with LSTM

## About

Music resembles language. Both are temporal sequences of articulated sounds, and both say something, often something human.

There are crucial differences between language and music. Even so, in its simplest form music can still be described as a sequence of symbols. This translates something complex into something simpler that computational models can use.

The objective of this project is therefore to establish communication between the machine and the human, who experiences music in the most intense way the brain can interpret information.

We'll create a model that generates music from input information. In other words, it will generate a sequence of sounds that is related in some way to the sounds passed as input.

We'll use Natural Language Processing (NLP) methods, treating music as if it were a language and abstracting it accordingly. This lets the machine recognize and process similar data.

As a first step, we'll use text generation techniques based on Recurrent Neural Networks (RNNs) and Long Short-Term Memory (LSTM) networks. If training proves effective, even if only reasonably so, we'll repeat the implementation with more specific methods such as Attention.



## Imports

In [44]:
# Basic libraries
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Preprocessing data libraries
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

# Model libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Data visualization
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm  # for progress bars

In [2]:
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
!git clone https://github.com/edufantini/music-gen.git

Cloning into 'music-gen'...
remote: Enumerating objects: 19522, done.
remote: Counting objects: 100% (54/54), done.
remote: Compressing objects: 100% (39/39), done.
remote: Total 19522 (delta 19), reused 43 (delta 14), pack-reused 19468
Receiving objects: 100% (19522/19522), 221.29 MiB | 23.34 MiB/s, done.
Resolving deltas: 100% (89/89), done.


In [4]:
from music_gen.src.GetData import *

## Dataset

In [5]:
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [6]:
import os
os.environ['KAGGLE_CONFIG_DIR'] = "/content/gdrive/My Drive/Kaggle"

In [7]:
%cd /content/gdrive/My Drive/Kaggle

/content/gdrive/My Drive/Kaggle


In [8]:
!kaggle datasets download -d edufantini/songs-in-midi

Downloading songs-in-midi.zip to /content/gdrive/My Drive/Kaggle
 99% 224M/226M [00:02<00:00, 88.6MB/s]
100% 226M/226M [00:02<00:00, 82.7MB/s]


In [9]:
!ls

clean_midi  dataset  kaggle.json  music-gen  songs-in-midi.zip


In [10]:
!unzip \*.zip  && rm *.zip

Archive:  songs-in-midi.zip
replace clean_midi/.38 Special/Caught Up In You.mid? [y]es, [n]o, [A]ll, [N]one, [r]ename: N


In [11]:
path = '/content/gdrive/My Drive/Kaggle/clean_midi/AC_DC/'

dataset = []

for filename in os.listdir(path):
  if filename.endswith(".mid"):
    print(path + filename)
    # encode_data (from GetData) returns the song's parts, each a list of encoded bars
    data = encode_data(path + filename, 32)
    dataset.append(data)

/content/gdrive/My Drive/Kaggle/clean_midi/AC_DC/Dirty_Deeds_Done_Dirt_Cheap.mid
Processing file /content/gdrive/My Drive/Kaggle/clean_midi/AC_DC/Dirty_Deeds_Done_Dirt_Cheap.mid
Processing part 1/3


Converting measures from part 1: 100%|###############################################################################| 58/58 [00:00<00:00, 272.90it/s]


Processing part 2/3


Converting measures from part 2: 100%|##############################################################################| 63/63 [00:00<00:00, 2566.00it/s]



Processing part 3/3


Converting measures from part 3: 100%|#############################################################################| 117/117 [00:00<00:00, 153.04it/s]


/content/gdrive/My Drive/Kaggle/clean_midi/AC_DC/For_Those_About_To_Rock_We_Salute_You_.mid
Processing file /content/gdrive/My Drive/Kaggle/clean_midi/AC_DC/For_Those_About_To_Rock_We_Salute_You_.mid
Processing part 1/9


Converting measures from part 1: 100%|###############################################################################| 94/94 [00:00<00:00, 547.59it/s]



Processing part 2/9


Converting measures from part 2: 100%|#############################################################################| 125/125 [00:00<00:00, 400.26it/s]


Processing part 3/9


Converting measures from part 3: 100%|#############################################################################| 125/125 [00:00<00:00, 423.48it/s]


Processing part 4/9


Converting measures from part 4: 100%|#############################################################################| 157/157 [00:00<00:00, 237.21it/s]


Processing part 5/9


Converting measures from part 5: 100%|###############################################################################| 10/10 [00:00<00:00, 746.73it/s]



Processing part 6/9


Converting measures from part 6: 100%|##############################################################################| 89/89 [00:00<00:00, 2102.15it/s]



Processing part 7/9


Converting measures from part 7: 100%|#############################################################################| 125/125 [00:00<00:00, 424.98it/s]


Processing part 8/9
/content/gdrive/My Drive/Kaggle/clean_midi/AC_DC/Highway_To_Hell.1.mid
Processing file /content/gdrive/My Drive/Kaggle/clean_midi/AC_DC/Highway_To_Hell.1.mid
Processing part 1/4


Converting measures from part 1: 100%|###############################################################################| 90/90 [00:00<00:00, 351.82it/s]


Processing part 2/4


Converting measures from part 2: 100%|###############################################################################| 57/57 [00:00<00:00, 822.04it/s]



Processing part 3/4


Converting measures from part 3: 100%|##############################################################################| 56/56 [00:00<00:00, 1224.42it/s]


Processing part 4/4


Converting measures from part 4: 100%|###############################################################################| 62/62 [00:00<00:00, 544.35it/s]


Harmonica
/content/gdrive/My Drive/Kaggle/clean_midi/AC_DC/R.I.P._Rock_in_Peace_.mid
Processing file /content/gdrive/My Drive/Kaggle/clean_midi/AC_DC/R.I.P._Rock_in_Peace_.mid
Processing part 1/7


Converting measures from part 1: 100%|###############################################################################| 47/47 [00:00<00:00, 385.94it/s]


Processing part 2/7


Converting measures from part 2: 100%|###############################################################################| 46/46 [00:00<00:00, 397.22it/s]


Processing part 3/7


Converting measures from part 3: 100%|##############################################################################| 19/19 [00:00<00:00, 4015.51it/s]


Processing part 4/7


Converting measures from part 4: 100%|###############################################################################| 39/39 [00:00<00:00, 477.86it/s]


Processing part 5/7


Converting measures from part 5: 100%|###############################################################################| 39/39 [00:00<00:00, 444.62it/s]


Processing part 6/7


Converting measures from part 6: 100%|###############################################################################| 46/46 [00:00<00:00, 297.38it/s]


Processing part 7/7


Converting measures from part 7: 100%|##############################################################################| 16/16 [00:00<00:00, 1738.21it/s]


/content/gdrive/My Drive/Kaggle/clean_midi/AC_DC/You_Shook_Me_All_Night_Long.mid
Processing file /content/gdrive/My Drive/Kaggle/clean_midi/AC_DC/You_Shook_Me_All_Night_Long.mid
Processing part 1/6


Converting measures from part 1: 100%|#############################################################################| 107/107 [00:00<00:00, 287.31it/s]
Converting measures from part 2: 100%|#################################################################################| 1/1 [00:00<00:00, 959.58it/s]


Processing part 2/6
Processing part 3/6


Converting measures from part 3: 100%|#############################################################################| 103/103 [00:00<00:00, 281.19it/s]


Processing part 4/6


Converting measures from part 4: 100%|###############################################################################| 93/93 [00:00<00:00, 470.82it/s]


Processing part 5/6


Converting measures from part 5: 100%|#################################################################################| 3/3 [00:00<00:00, 772.57it/s]


Processing part 6/6


Converting measures from part 6: 100%|#################################################################################| 2/2 [00:00<00:00, 825.65it/s]


/content/gdrive/My Drive/Kaggle/clean_midi/AC_DC/Whole_Lotta_Rosie.mid
Processing file /content/gdrive/My Drive/Kaggle/clean_midi/AC_DC/Whole_Lotta_Rosie.mid
Processing part 1/5


Converting measures from part 1: 100%|#############################################################################| 161/161 [00:00<00:00, 304.26it/s]


Processing part 2/5


Converting measures from part 2: 100%|#############################################################################| 160/160 [00:01<00:00, 151.51it/s]


Processing part 3/5


Converting measures from part 3: 100%|#############################################################################| 161/161 [00:00<00:00, 228.53it/s]


Processing part 4/5


Converting measures from part 4: 100%|###############################################################################| 71/71 [00:00<00:00, 547.20it/s]


Processing part 5/5


Converting measures from part 5: 100%|##############################################################################| 98/98 [00:00<00:00, 1543.94it/s]



/content/gdrive/My Drive/Kaggle/clean_midi/AC_DC/TNT.1.mid
Processing file /content/gdrive/My Drive/Kaggle/clean_midi/AC_DC/TNT.1.mid
Processing part 1/5


[None, None, None, <music21.beam.Beams <music21.beam.Beam 1/start>/<music21.beam.Beam 2/partial/right>>, <music21.beam.Beams <music21.beam.Beam 1/stop>/<music21.beam.Beam 2/stop>>]
[None, None, None, <music21.beam.Beams <music21.beam.Beam 1/start>/<music21.beam.Beam 2/partial/right>>, <music21.beam.Beams <music21.beam.Beam 1/stop>/<music21.beam.Beam 2/stop>>]
[None, None, None, <music21.beam.Beams <music21.beam.Beam 1/start>/<music21.beam.Beam 2/partial/right>>, <music21.beam.Beams <music21.beam.Beam 1/stop>/<music21.beam.Beam 2/stop>>]
[None, None, None, <music21.beam.Beams <music21.beam.Beam 1/start>/<music21.beam.Beam 2/partial/right>>, <music21.beam.Beams <music21.beam.Beam 1/stop>/<music21.beam.Beam 2/stop>>]
Converting measures from part 1: 100%|#############################################################################| 107/107 [00:00<00:00, 172.20it/s]


Processing part 2/5


[<music21.beam.Beams <music21.beam.Beam 1/start>>, <music21.beam.Beams <music21.beam.Beam 1/stop>>, <music21.beam.Beams <music21.beam.Beam 1/start>/<music21.beam.Beam 2/partial/right>>, <music21.beam.Beams <music21.beam.Beam 1/stop>/<music21.beam.Beam 2/stop>>]
[<music21.beam.Beams <music21.beam.Beam 1/start>>, <music21.beam.Beams <music21.beam.Beam 1/stop>>, <music21.beam.Beams <music21.beam.Beam 1/start>/<music21.beam.Beam 2/partial/right>>, <music21.beam.Beams <music21.beam.Beam 1/stop>/<music21.beam.Beam 2/stop>>]
[<music21.beam.Beams <music21.beam.Beam 1/start>>, <music21.beam.Beams <music21.beam.Beam 1/stop>>, <music21.beam.Beams <music21.beam.Beam 1/start>/<music21.beam.Beam 2/partial/right>>, <music21.beam.Beams <music21.beam.Beam 1/stop>/<music21.beam.Beam 2/stop>>]
[<music21.beam.Beams <music21.beam.Beam 1/start>/<music21.beam.Beam 2/start>>, <music21.beam.Beams <music21.beam.Beam 1/stop>/<music21.beam.Beam 2/stop>>, <music21.beam.Beams <music21.beam.Beam 1/start>/<music21.be

Processing part 3/5


[None, None, None, <music21.beam.Beams <music21.beam.Beam 1/start>/<music21.beam.Beam 2/partial/right>>, <music21.beam.Beams <music21.beam.Beam 1/stop>/<music21.beam.Beam 2/stop>>]
[None, None, None, <music21.beam.Beams <music21.beam.Beam 1/start>/<music21.beam.Beam 2/partial/right>>, <music21.beam.Beams <music21.beam.Beam 1/stop>/<music21.beam.Beam 2/stop>>]
[None, None, None, <music21.beam.Beams <music21.beam.Beam 1/start>/<music21.beam.Beam 2/partial/right>>, <music21.beam.Beams <music21.beam.Beam 1/stop>/<music21.beam.Beam 2/stop>>]
[None, None, None, <music21.beam.Beams <music21.beam.Beam 1/start>/<music21.beam.Beam 2/partial/right>>, <music21.beam.Beams <music21.beam.Beam 1/stop>/<music21.beam.Beam 2/stop>>]
Converting measures from part 3: 100%|###############################################################################| 98/98 [00:00<00:00, 215.27it/s]


Processing part 4/5


Converting measures from part 4: 100%|##############################################################################| 25/25 [00:00<00:00, 2273.73it/s]


Processing part 5/5


Converting measures from part 5: 100%|###############################################################################| 86/86 [00:00<00:00, 473.22it/s]


/content/gdrive/My Drive/Kaggle/clean_midi/AC_DC/Highway_To_Hell.mid
Processing file /content/gdrive/My Drive/Kaggle/clean_midi/AC_DC/Highway_To_Hell.mid
Processing part 1/4


Converting measures from part 1: 100%|###############################################################################| 90/90 [00:00<00:00, 328.39it/s]


Processing part 2/4


Converting measures from part 2: 100%|###############################################################################| 57/57 [00:00<00:00, 838.99it/s]



Processing part 3/4


Converting measures from part 3: 100%|##############################################################################| 56/56 [00:00<00:00, 1240.32it/s]


Processing part 4/4


Converting measures from part 4: 100%|###############################################################################| 62/62 [00:00<00:00, 507.12it/s]



/content/gdrive/My Drive/Kaggle/clean_midi/AC_DC/TNT.mid
Processing file /content/gdrive/My Drive/Kaggle/clean_midi/AC_DC/TNT.mid
Processing part 1/1
/content/gdrive/My Drive/Kaggle/clean_midi/AC_DC/Thunderstruck.1.mid
Processing file /content/gdrive/My Drive/Kaggle/clean_midi/AC_DC/Thunderstruck.1.mid
Processing part 1/1


[<music21.beam.Beams <music21.beam.Beam 1/start>/<music21.beam.Beam 2/partial/right>/<music21.beam.Beam 3/partial/right>>, <music21.beam.Beams <music21.beam.Beam 1/continue>>, <music21.beam.Beams <music21.beam.Beam 1/stop>>, <music21.beam.Beams <music21.beam.Beam 1/start>>, <music21.beam.Beams <music21.beam.Beam 1/stop>>, <music21.beam.Beams <music21.beam.Beam 1/start>>, <music21.beam.Beams <music21.beam.Beam 1/stop>>, <music21.beam.Beams <music21.beam.Beam 1/start>/<music21.beam.Beam 2/partial/right>>, <music21.beam.Beams <music21.beam.Beam 1/stop>/<music21.beam.Beam 2/stop>>]
[<music21.beam.Beams <music21.beam.Beam 1/start>/<music21.beam.Beam 2/partial/right>/<music21.beam.Beam 3/partial/right>>, <music21.beam.Beams <music21.beam.Beam 1/continue>>, <music21.beam.Beams <music21.beam.Beam 1/stop>>, <music21.beam.Beams <music21.beam.Beam 1/start>>, <music21.beam.Beams <music21.beam.Beam 1/stop>>, <music21.beam.Beams <music21.beam.Beam 1/start>>, <music21.beam.Beams <music21.beam.Beam 1/

/content/gdrive/My Drive/Kaggle/clean_midi/AC_DC/Who_Made_Who.mid
Processing file /content/gdrive/My Drive/Kaggle/clean_midi/AC_DC/Who_Made_Who.mid
Processing part 1/4


Converting measures from part 1: 100%|##############################################################################| 44/44 [00:00<00:00, 1564.16it/s]



Processing part 2/4


Converting measures from part 2: 100%|#############################################################################| 112/112 [00:00<00:00, 322.75it/s]


Processing part 3/4


Converting measures from part 3: 100%|###############################################################################| 66/66 [00:00<00:00, 946.83it/s]



Processing part 4/4


Converting measures from part 4: 100%|##############################################################################| 106/106 [00:01<00:00, 82.90it/s]


/content/gdrive/My Drive/Kaggle/clean_midi/AC_DC/Back_In_Black.mid
Processing file /content/gdrive/My Drive/Kaggle/clean_midi/AC_DC/Back_In_Black.mid
Processing part 1/8
/content/gdrive/My Drive/Kaggle/clean_midi/AC_DC/You_Shook_Me_All_Night_Long.1.mid
Processing file /content/gdrive/My Drive/Kaggle/clean_midi/AC_DC/You_Shook_Me_All_Night_Long.1.mid
Processing part 1/6


Converting measures from part 1: 100%|#############################################################################| 107/107 [00:00<00:00, 275.39it/s]
Converting measures from part 2: 100%|#################################################################################| 1/1 [00:00<00:00, 465.67it/s]


Processing part 2/6
Processing part 3/6


Converting measures from part 3: 100%|#############################################################################| 103/103 [00:00<00:00, 259.88it/s]


Processing part 4/6


Converting measures from part 4: 100%|###############################################################################| 93/93 [00:00<00:00, 444.86it/s]


Processing part 5/6


Converting measures from part 5: 100%|#################################################################################| 3/3 [00:00<00:00, 984.42it/s]


Processing part 6/6


Converting measures from part 6: 100%|#################################################################################| 2/2 [00:00<00:00, 877.01it/s]


/content/gdrive/My Drive/Kaggle/clean_midi/AC_DC/Thunderstruck.mid
Processing file /content/gdrive/My Drive/Kaggle/clean_midi/AC_DC/Thunderstruck.mid
Processing part 1/1


[<music21.beam.Beams <music21.beam.Beam 1/start>/<music21.beam.Beam 2/partial/right>/<music21.beam.Beam 3/partial/right>>, <music21.beam.Beams <music21.beam.Beam 1/continue>>, <music21.beam.Beams <music21.beam.Beam 1/stop>>, <music21.beam.Beams <music21.beam.Beam 1/start>>, <music21.beam.Beams <music21.beam.Beam 1/stop>>, <music21.beam.Beams <music21.beam.Beam 1/start>>, <music21.beam.Beams <music21.beam.Beam 1/stop>>, <music21.beam.Beams <music21.beam.Beam 1/start>/<music21.beam.Beam 2/partial/right>>, <music21.beam.Beams <music21.beam.Beam 1/stop>/<music21.beam.Beam 2/stop>>]
[<music21.beam.Beams <music21.beam.Beam 1/start>/<music21.beam.Beam 2/partial/right>/<music21.beam.Beam 3/partial/right>>, <music21.beam.Beams <music21.beam.Beam 1/continue>>, <music21.beam.Beams <music21.beam.Beam 1/stop>>, <music21.beam.Beams <music21.beam.Beam 1/start>>, <music21.beam.Beams <music21.beam.Beam 1/stop>>, <music21.beam.Beams <music21.beam.Beam 1/start>>, <music21.beam.Beams <music21.beam.Beam 1/

/content/gdrive/My Drive/Kaggle/clean_midi/AC_DC/Highway_To_Hell.2.mid
Processing file /content/gdrive/My Drive/Kaggle/clean_midi/AC_DC/Highway_To_Hell.2.mid
Processing part 1/5


Acoustic Bass


Converting measures from part 1: 100%|###############################################################################| 84/84 [00:00<00:00, 318.49it/s]


Processing part 2/5


Fretless Bass


Converting measures from part 2: 100%|###############################################################################| 83/83 [00:00<00:00, 301.47it/s]


Processing part 3/5


Converting measures from part 3: 100%|###############################################################################| 90/90 [00:00<00:00, 221.81it/s]


Processing part 4/5


Converting measures from part 4: 100%|###############################################################################| 90/90 [00:00<00:00, 216.66it/s]


Processing part 5/5


Converting measures from part 5: 100%|###############################################################################| 90/90 [00:00<00:00, 220.76it/s]


/content/gdrive/My Drive/Kaggle/clean_midi/AC_DC/Back_In_Black.1.mid
Processing file /content/gdrive/My Drive/Kaggle/clean_midi/AC_DC/Back_In_Black.1.mid
Processing part 1/7


Converting measures from part 1: 100%|###############################################################################| 92/92 [00:00<00:00, 108.06it/s]


Processing part 2/7


Converting measures from part 2: 100%|###############################################################################| 92/92 [00:00<00:00, 111.60it/s]


Processing part 3/7


Converting measures from part 3: 100%|###############################################################################| 92/92 [00:00<00:00, 223.57it/s]


Processing part 4/7


Converting measures from part 4: 100%|###############################################################################| 75/75 [00:00<00:00, 224.06it/s]


Processing part 5/7


Converting measures from part 5: 100%|###############################################################################| 55/55 [00:00<00:00, 408.94it/s]



Processing part 6/7


Converting measures from part 6: 100%|###############################################################################| 55/55 [00:00<00:00, 346.75it/s]



Processing part 7/7


StringInstrument


Converting measures from part 7: 100%|###############################################################################| 76/76 [00:00<00:00, 241.36it/s]


In [12]:
len(dataset)

16
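The loops in the next section treat each entry of `dataset` as a song made of parts. Each part is a list of bars, and each bar is a multi-hot matrix of 32 frames by 88 notes. `encode_data` lives in GetData and its internals aren't shown here, so what follows is only a hypothetical sketch of what one multi-hot bar looks like. Several details are assumptions rather than the project's encoder:

* the helper name `encode_bar`;
* the `(pitch, start_frame, length)` note format;
* the mapping of column 0 to MIDI pitch 21 (A0, the lowest piano key).

```python
import numpy as np

N_FRAMES = 32      # frames per bar, as stated in the preprocessing notes
N_KEYS = 88        # one column per piano key
LOWEST_MIDI = 21   # assumption: column 0 corresponds to MIDI pitch 21 (A0)

def encode_bar(notes, n_frames=N_FRAMES):
    """Hypothetical multi-hot encoder; notes is a list of (midi_pitch, start_frame, n_frames_held)."""
    bar = np.zeros((n_frames, N_KEYS), dtype=np.int64)
    for pitch, start, length in notes:
        col = pitch - LOWEST_MIDI
        if 0 <= col < N_KEYS:
            bar[start:start + length, col] = 1  # the key is "on" in every frame it is held
    return bar

# a C major triad (C4, E4, G4 = MIDI 60, 64, 67) held for the first 8 frames
bar = encode_bar([(60, 0, 8), (64, 0, 8), (67, 0, 8)])
print(bar.shape)            # (32, 88)
print(bar[0].nonzero()[0])  # [39 43 46]
```

Several keys can be active in the same frame, which is why the encoding is multi-hot rather than one-hot.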

## Preprocess data

```preprocess_bar(encoded_seq, n_in=32, n_out=32)``` takes one multi-hot encoded bar (32 frames) and preprocesses it into values split into X and y. X holds frames 0 to 30 and y holds frames 1 to 31, so each row of y is the frame that follows the same row of X.

In [13]:
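def preprocess_bar(encoded_seq, n_in=32, n_out=32):
  # create lagged copies of the sequence: block i is the bar shifted down by n_in-i-1 frames
  df = pd.DataFrame(encoded_seq)
  df = pd.concat([df.shift(n_in-i-1) for i in range(n_in)], axis=1)
  # drop rows with missing values; for a bar of exactly n_in frames only the last row
  # survives, and it holds frames 0..n_in-1 laid side by side
  df.dropna(inplace=True)
  # specify columns for input and output values
  values = df.values
  width = encoded_seq.shape[1]
  X = values[:, 0:width*(n_in-1)].reshape(n_in-1, width)  # frames 0..n_in-2 (context)
  y = values[:, width:].reshape(n_in-1, width)            # frames 1..n_in-1 (next-frame targets)
  # note: n_out is currently unused
  return X,y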
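By our reading of `preprocess_bar`, the shift, concat and dropna steps leave a single row for a bar of exactly `n_in` frames. For such a bar, the function should therefore reduce to a one-frame offset. The NumPy-only sketch below shows that simpler formulation. `preprocess_bar_simple` is a hypothetical name and the toy bar is random data, not part of the dataset.

```python
import numpy as np

def preprocess_bar_simple(encoded_seq):
    """Split one bar of shape (n_frames, n_notes) into next-frame pairs.

    X holds frames 0..n-2 (context) and y holds frames 1..n-1 (target),
    so y[t] is the frame that follows X[t].
    """
    encoded_seq = np.asarray(encoded_seq, dtype=np.float64)
    return encoded_seq[:-1], encoded_seq[1:]

# toy bar: 32 frames x 88 notes, one random note active per frame (hypothetical data)
rng = np.random.default_rng(0)
bar = np.zeros((32, 88))
bar[np.arange(32), rng.integers(0, 88, size=32)] = 1

X, y = preprocess_bar_simple(bar)
print(X.shape, y.shape)  # (31, 88) (31, 88)
```

Slicing avoids building a 32 × 2816 intermediate DataFrame for every bar. The output matches the float64 dtype that the pandas version produces, because `shift` introduces NaNs.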

```create_dataloader(dataset, batch_size=1)``` converts a dataset of n songs into a DataLoader.

In [14]:
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

def create_dataloader(dataset, batch_size=1):
  X = []
  y = []

  # create two arrays X, y with one (context, target) pair per bar
  for song in dataset:
    for part in song:
      for bar in part:
        xa, ya = preprocess_bar(bar)
        X.append(xa)
        y.append(ya)

  X = torch.from_numpy(np.array(X))  # (n_bars, 31, 88)
  y = torch.from_numpy(np.array(y))  # (n_bars, 31, 88)
  print(X.shape, y.shape)
  train_ds = TensorDataset(X, y)
  # shuffle=False keeps the bars in song order
  train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=False)

  return train_dl

The dataloader is divided into two parts:

  1.   context
  2.   target

Here, context holds the values passed to the model as input. In this case, these are 31 frames taken from each of the 5010 bars of the songs, and each frame has 88 notes. Target holds the same bar shifted forward by one frame, i.e. the frame the model should predict after each context frame.
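As a quick sanity check on the context/target layout, the sketch below builds a DataLoader the same way as `create_dataloader` (`TensorDataset` plus `DataLoader` with `batch_size=1`). The 10 random bars are hypothetical stand-ins for the real data. Each batch is a `(context, target)` pair, and the target is the context shifted by one frame.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(0)

# hypothetical stand-in: 10 bars of 32 frames x 88 notes (multi-hot)
bars = torch.randint(0, 2, (10, 32, 88)).double()
X = bars[:, :-1]   # context: frames 0..30 of each bar
y = bars[:, 1:]    # target:  frames 1..31 of each bar

dl = DataLoader(TensorDataset(X, y), batch_size=1, shuffle=False)

context, target = next(iter(dl))
print(context.shape, target.shape)  # torch.Size([1, 31, 88]) torch.Size([1, 31, 88])
```

With `batch_size=1` every batch keeps a leading dimension of size 1, which the training code later drops by indexing a single bar.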


In [15]:
train_dl = create_dataloader(dataset)

torch.Size([5010, 31, 88]) torch.Size([5010, 31, 88])


In [60]:
train_dl.dataset.tensors[0][200]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], dtype=torch.float64)

In [16]:
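def create_vocab(dataset):
  # collect every bar from every part of every song
  vocab = []
  for song in dataset:
    for part in song:
      for bar in part:
        vocab.append(bar)

  # flatten the bars into a list of frames: (n_bars * frames_per_bar, n_notes)
  vocab = np.array(vocab)
  vocab = vocab.reshape(vocab.shape[0]*vocab.shape[1], vocab.shape[2])
  # keep only distinct frames (distinct combinations of active notes)
  vocab = np.unique(vocab, axis=0)

  print(vocab)

  return vocab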

The dataset contains 626 distinct frames (see the vocab length printed below).

In [17]:
diff_frames = create_vocab(dataset)
print('\nvocab len: {}'.format(diff_frames.shape[0]))

[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]

vocab len: 626
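`create_vocab` counts distinct frames in two steps. It flattens all bars into one list of frames, then calls `np.unique` with `axis=0`, which removes duplicate rows. The toy example below repeats those steps on two hand-made bars.

It also passes `return_inverse=True`, which maps each frame to an integer id. That mapping is one possible way to turn frames into tokens for an embedding layer such as the commented-out `nn.Embedding` in the model. This notebook does not do this; treat it as a hypothetical extension.

```python
import numpy as np

# two toy bars of 4 frames x 5 notes (hypothetical data); some frames repeat
bars = np.array([
    [[1, 0, 0, 0, 0], [1, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]],
    [[0, 1, 0, 1, 0], [0, 0, 0, 0, 1], [1, 0, 0, 0, 0], [0, 0, 0, 0, 0]],
])

# flatten to (n_bars * frames_per_bar, n_notes), then keep distinct rows
frames = bars.reshape(-1, bars.shape[2])
vocab, ids = np.unique(frames, axis=0, return_inverse=True)
ids = ids.reshape(-1)  # one integer token per frame

print(len(vocab))  # 4
```

Indexing `vocab[ids]` rebuilds the original frame sequence, so the ids are a lossless tokenization of the frames.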


## Model

Some important definitions

In [None]:
n_frames_input = 31
n_frames_output = 31
n_bars_input = len(train_dl.dataset.tensors[0]) # number of rows of the dataloader
bar_len = 31 # number of frames (timesteps) in each input sequence
num_layers = 31 # number of stacked LSTM layers (not the sequence length)
frame_len = 88
hidden_size = 88
num_epochs = 3
batch_size = 1
lr = 0.003
print('Number of bars in the input dataset: {}'.format(n_bars_input))

In [45]:
class RNN(nn.Module):
  def __init__(self, input_size, hidden_size, num_layers, output_size):
    super(RNN, self).__init__()
    self.hidden_size = hidden_size
    self.num_layers = num_layers

    #self.embed = nn.Embedding(input_size, hidden_size)
    self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=False)
    self.fc = nn.Linear(hidden_size, output_size)
    self.act = nn.Hardsigmoid()

  def forward(self, x, hidden, cell):

    # Passing in the input and hidden state into the model and obtaining outputs
    out, (hidden, cell) = self.lstm(x.unsqueeze(1), (hidden, cell))

    # Reshaping the outputs such that it can be fit into the fully connected layer
    out = self.fc(out.contiguous().view(-1, self.hidden_size))
    out = self.act(out)
    
    return out, (hidden, cell)

  def init_hidden(self, batch_size):
    # This method generates the first hidden state of zeros which we'll use in the forward pass
    # We'll send the tensor holding the hidden state to the device we specified earlier as well
    hidden = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)
    cell = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)
    return hidden, cell
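In `RNN.forward`, the 31 frames of a bar become the LSTM's time axis. `x.unsqueeze(1)` adds a batch dimension of 1, and `batch_first=False` makes the first dimension time. The `num_layers` argument of `nn.LSTM`, by contrast, sets how many LSTM layers are stacked on top of each other. So `num_layers = 31` builds a 31-deep stack rather than 31 timesteps.

The sketch below reproduces the forward pass with a standalone `nn.LSTM`, `nn.Linear` and `nn.Hardsigmoid`. It uses a 2-layer stack chosen only for illustration and checks the shapes and the LSTM parameter count.

```python
import torch
import torch.nn as nn

seq_len, n_notes, hidden = 31, 88, 88
num_layers = 2  # depth of the LSTM stack; illustrative value, unrelated to seq_len

lstm = nn.LSTM(n_notes, hidden, num_layers, batch_first=False)
fc = nn.Linear(hidden, n_notes)
act = nn.Hardsigmoid()

bar = torch.zeros(seq_len, n_notes)  # one bar: 31 frames x 88 notes
x = bar.unsqueeze(1)                 # (seq_len, batch=1, n_notes), as in RNN.forward
h0 = torch.zeros(num_layers, 1, hidden)
c0 = torch.zeros(num_layers, 1, hidden)

out, (h, c) = lstm(x, (h0, c0))
y = act(fc(out.reshape(-1, hidden)))  # one 88-note prediction per frame

print(out.shape)  # torch.Size([31, 1, 88])
print(h.shape)    # torch.Size([2, 1, 88])
print(y.shape)    # torch.Size([31, 88])

# each layer has 4 gates, each with input weights, recurrent weights and two bias vectors
per_layer = 4 * (hidden * n_notes + hidden * hidden + 2 * hidden)
n_params = sum(p.numel() for p in lstm.parameters())
print(n_params == num_layers * per_layer)  # True (input size equals hidden size here)
```

With 88 inputs and 88 hidden units, each layer holds 62,656 parameters. The notebook's 31 layers therefore carry about 1.94 million LSTM parameters, all to model one 88-note frame at a time.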

In [83]:
# Instantiate the model with hyperparameters
# We'll also move the model to the device defined earlier (GPU if available, otherwise CPU)
model = RNN(input_size=frame_len,
            output_size=frame_len,
            hidden_size=hidden_size,
            num_layers=num_layers).to(device)

## Train

In [84]:
# converts one frame into torch tensor
def multi_hot_tensor(frame):
  tensor = torch.from_numpy(frame)
  return tensor

In [85]:
# retrieve data from dataloader
def get_sample(dataloader):

  input = torch.zeros(n_bars_input, bar_len, frame_len)
  target = torch.zeros(n_bars_input, bar_len, frame_len)

  for sample, (xb, yb) in enumerate(dataloader): # gets the samples
    input[sample] = xb
    target[sample] = yb
  
  return input, target

In [89]:
def train(model, optimizer, loss_fn, dataloader, batch_size=1, num_epochs=3):

  print("\nStarting training...")

  for epoch in range(1, num_epochs + 1):
    training_loss = 0.0

    print('> EPOCH #', epoch)

    input, target = get_sample(dataloader)
    input = input.to(device)
    target = target.to(device)

    for bar in tqdm(range(n_bars_input)):
      # Initialize hidden and cells
      hidden, cell = model.init_hidden(batch_size)

      # Generate predictions
      output, (hidden, cell) = model(input[bar,:], hidden, cell)

      # Compute the loss and backpropagate
      loss_step = loss_fn(output, target[bar, :])
      loss_step.backward() # Does backpropagation and calculates gradients
      optimizer.step() # Updates the weights accordingly
      optimizer.zero_grad() # Clears the gradients before the next bar
      
      training_loss += loss_step.item()
    
    training_loss /= len(dataloader.dataset)  # average loss per bar
      
    if epoch%1 == 0:
      print('Epoch: {}/{}.............'.format(epoch, num_epochs), end=' ')
      print("Loss: {:.4f}".format(training_loss))

In [90]:
optimizer = optim.Adam(model.parameters(), lr=lr)
loss_fn = nn.BCELoss()
train(model, optimizer, loss_fn, train_dl)


Starting training...
> EPOCH # 1


Epoch: 1/3............. Loss: 0.8174
> EPOCH # 2


Epoch: 2/3............. Loss: 0.8162
> EPOCH # 3


Epoch: 3/3............. Loss: 0.8162
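The loss barely moves across the three epochs (0.8174, 0.8162, 0.8162). One plausible contributor, not verified here, is the Hardsigmoid output layer. Its gradient is zero outside [-3, 3], so notes whose pre-activations saturate stop receiving updates.

A common alternative is to let the final linear layer emit raw logits and train with `nn.BCEWithLogitsLoss`, which applies the sigmoid internally in a numerically stable way. This is a swapped-in technique, not the notebook's method.

The sketch below checks that this loss matches `BCELoss` on sigmoid outputs, then runs one Adam step. It uses random stand-in data with the notebook's shapes and a small 2-layer LSTM.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# random stand-in data with the notebook's shapes: 31 frames x 88 notes
x = torch.rand(31, 1, 88)                       # (seq_len, batch, notes)
target = torch.randint(0, 2, (31, 88)).float()  # multi-hot targets

lstm = nn.LSTM(88, 88, num_layers=2)
fc = nn.Linear(88, 88)                          # emits logits; no output activation
optimizer = torch.optim.Adam(list(lstm.parameters()) + list(fc.parameters()), lr=0.003)
loss_fn = nn.BCEWithLogitsLoss()

out, _ = lstm(x)
logits = fc(out.reshape(-1, 88))                # (31, 88)
loss = loss_fn(logits, target)

# same value as BCELoss on sigmoid probabilities, but computed stably from the logits
reference = nn.BCELoss()(torch.sigmoid(logits), target)

optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"loss: {loss.item():.4f}")
```

At generation time, `torch.sigmoid(logits)` turns the logits back into per-note probabilities.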


## Test

In [112]:
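def generate(model, initial_bar, predict_len=31, batch_size=1, temperature=0.85):
  # note: temperature is currently unused, and every step re-feeds initial_bar;
  # only the hidden and cell states carry over between steps
  output = []
  model.eval()
  with torch.no_grad():  # no gradients are needed for inference
    hidden, cell = model.init_hidden(batch_size)
    initial_bar = initial_bar.to(device)
    for _ in range(predict_len):
      out, (hidden, cell) = model(initial_bar, hidden, cell)
      output.append(out)

  return output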
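`generate` accepts a `temperature` argument but never uses it, and each step feeds the same `initial_bar` back in. A common way to generate autoregressively goes as follows:

1. Warm up the hidden state on the seed bar.
2. Scale the logits of the next frame by the temperature and turn them into per-note probabilities.
3. Sample each note independently as on/off.
4. Feed the sampled frame back in.

The sketch below shows that loop with a hypothetical stand-in model: a fresh, untrained `nn.LSTM` plus `nn.Linear` emitting logits, not the notebook's trained `RNN`. The name `sample_frames`, the Bernoulli sampling and the seed bar are all assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

N_NOTES = 88

# hypothetical stand-in for the notebook's RNN: an LSTM followed by a linear layer producing logits
lstm = nn.LSTM(N_NOTES, N_NOTES, num_layers=2)
fc = nn.Linear(N_NOTES, N_NOTES)

def sample_frames(seed_bar, n_new=8, temperature=0.85):
    """Autoregressively sample n_new frames after seed_bar (shape: frames x notes)."""
    generated = []
    with torch.no_grad():
        # warm up the hidden state on the whole seed bar (seq-first, batch of 1)
        out, state = lstm(seed_bar.unsqueeze(1))
        logits = fc(out[-1, 0])  # logits for the frame after the seed
        for _ in range(n_new):
            probs = torch.sigmoid(logits / temperature)  # lower temperature -> sharper probabilities
            frame = torch.bernoulli(probs)               # each note independently on (1) or off (0)
            generated.append(frame)
            out, state = lstm(frame.view(1, 1, -1), state)  # feed the sampled frame back in
            logits = fc(out[-1, 0])
    return torch.stack(generated)

seed = torch.zeros(31, N_NOTES)
seed[:, 12] = 1.0  # one sustained note, similar to the test input used below
new_frames = sample_frames(seed)
print(new_frames.shape)  # torch.Size([8, 88])
```

Dividing logits by a temperature below 1 pushes probabilities toward 0 or 1, which gives more conservative output. Values above 1 flatten them toward 0.5.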

In [119]:
initial_bar = train_dl.dataset.tensors[0][3847]
initial_bar = initial_bar.to(torch.float)
print(f'Input: ')
torch.set_printoptions(threshold=10_000)
print(initial_bar)

output = generate(model, initial_bar, predict_len=1)
print(f'Output: ')
print(output)

Input: 
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0