In [1]:
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
import torch
import pandas as pd
import torch.nn as nn
import torch.optim as optim
# from torch.distributions.categorical import Categorical
from EPMS.serialization import SETTINGS
import os

In [2]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [3]:
class EPMSDataset(Dataset):
    """Next-step prediction dataset built from a directory of pickled note tables.

    Every ``.pkl`` file in ``table_dir`` is one sample. For a sample, the rows
    whose ``INSTRUMENT`` column equals ``instrument`` are kept, their last
    ``keyboard_size`` columns (the note columns) are rounded up with
    ``np.ceil``, and the sample is the pair ``(notes[:-1], notes[1:])``. The
    model therefore learns to predict row t+1 from row t.

    Args:
        table_dir: Directory containing the pickled DataFrames.
        instrument: Value of the ``INSTRUMENT`` column to keep.
        keyboard_size: Number of trailing note columns in each table. Defaults
            to ``SETTINGS["KEYBOARD_SIZE"]`` (the original behavior).
    """

    def __init__(self, table_dir, instrument, keyboard_size=None):
        self.instrument = instrument
        self.keyboard_size = SETTINGS["KEYBOARD_SIZE"] if keyboard_size is None else keyboard_size
        self.dataframes = []
        # os.listdir order is arbitrary; sorting keeps sample indices stable across runs/platforms.
        for filename in sorted(os.listdir(table_dir)):
            if not filename.endswith('.pkl'):
                continue
            # NOTE: pd.read_pickle can execute arbitrary code from the file -- only load trusted tables.
            self.dataframes.append(pd.read_pickle(os.path.join(table_dir, filename)))

    def get_dataframe_notes(self, dataframe):
        """Return the note columns of ``self.instrument``'s rows as a float ndarray.

        The note columns are taken to be the last ``self.keyboard_size`` columns
        of ``dataframe``. Result shape: (n_rows_for_instrument, keyboard_size).
        """
        dataframe_instrument = dataframe[dataframe.INSTRUMENT == self.instrument]
        initial_note = len(dataframe.columns) - self.keyboard_size
        return dataframe_instrument.iloc[:, initial_note:].astype("float").to_numpy()

    def __len__(self):
        """Number of tables (one sequence per table)."""
        return len(self.dataframes)

    def __getitem__(self, idx):
        """Return ``(inputs, targets)`` for table ``idx``.

        ``inputs`` is the ceiled note matrix without its last row (float64
        tensor, as before). ``targets`` is the same matrix shifted by one row
        (float32 tensor, as required by ``nn.BCELoss`` against float32 outputs).

        Raises:
            ValueError: if the table has fewer than two rows for the instrument,
                so no (input, target) pair can be formed.
        """
        # NOTE(review): presumably note values lie in [0, 1], so ceil maps any
        # non-zero note to 1 -- confirm against the table format.
        notes = np.ceil(self.get_dataframe_notes(self.dataframes[idx]))
        if len(notes) < 2:
            raise ValueError(
                f"Table {idx} has {len(notes)} row(s) for instrument {self.instrument!r}; "
                "need at least 2 to form an (input, target) pair."
            )
        return torch.tensor(notes[:-1]), torch.tensor(notes[1:]).float()


In [4]:
# Each .pkl table under the dataset directory becomes one training sequence.
dataset_name = 'C_major_scale'
# NOTE(review): the default collate stacks samples, so batch_size > 1 only
# works if every table yields the same number of rows -- confirm before raising it.
batch_size = 1
dataset = EPMSDataset(f"{dataset_name}/", "Piano")
dataloader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=batch_size,
    drop_last=True,
)

In [5]:
class PapagaioLSTM(nn.Module):
    """Embedding + LSTM model producing per-key probabilities for the next step.

    Each key's value in a time step (cast with ``.long()``) is looked up in an
    embedding table. The per-key embeddings of the step are concatenated and fed
    through an LSTM, a linear layer and a sigmoid.

    NOTE(review): the embedding is indexed by each key's *value*, not by its
    position, so only as many rows as there are distinct values are actually
    used (presumably 0/1 after the dataset's ceil). ``vocab_size`` rows is
    likely more than needed -- confirm intent.

    Args:
        vocab_size: Number of keys per time step. Also used as the number of
            embedding rows and of output units.
        embedding_dim: Size of each key's embedding vector.
        hidden_size: LSTM hidden size.
        n_layers: Number of stacked LSTM layers.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_size, n_layers):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.embedding_layer = nn.Embedding(vocab_size, embedding_dim)
        # Input size is embedding_dim * vocab_size because all key embeddings of a step are concatenated.
        self.lstm = nn.LSTM(embedding_dim*vocab_size, hidden_size, n_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        # Outputs are probabilities (paired with nn.BCELoss); returning logits with
        # nn.BCEWithLogitsLoss would be more numerically stable.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, h0=None, c0=None):
        """Run the model over a batch of sequences.

        Args:
            x: Tensor of shape (batch, seq_len, vocab_size). Values are cast to
                long and used as embedding indices, so they must lie in
                [0, vocab_size).
            h0, c0: Optional initial LSTM hidden/cell states, each of shape
                (n_layers, batch, hidden_size). Pass both or neither.

        Returns:
            ``(out, (hf, cf))``. ``out`` has shape (batch, seq_len, vocab_size)
            with sigmoid values in [0, 1]. ``hf``/``cf`` are the final LSTM states.

        Raises:
            ValueError: if exactly one of ``h0`` / ``c0`` is given (previously the
                provided state was silently ignored).
        """
        if (h0 is None) != (c0 is None):
            raise ValueError("h0 and c0 must be provided together (or both omitted)")
        x = self.embedding_layer(x.long())
        # (batch, seq, vocab, emb) -> (batch, seq, vocab * emb)
        x = x.flatten(2, 3)
        # nn.LSTM treats hx=None as a zero initial state.
        state = None if h0 is None else (h0, c0)
        out, (hf, cf) = self.lstm(x, state)
        out = self.sigmoid(self.linear(out))
        return out, (hf, cf)

In [6]:
# Model and optimisation hyperparameters.
vocab_size = SETTINGS["KEYBOARD_SIZE"]  # one input/output unit per keyboard key
embedding_dim, hidden_size, n_layers = 64, 1024, 1
epochs, lr = 100, 1e-3

In [7]:
# Seed before building the model so weight init and DataLoader shuffling are
# reproducible on Restart & Run All (cuDNN LSTM kernels may still be nondeterministic).
torch.manual_seed(0)

papagaio_lstm = PapagaioLSTM(vocab_size, embedding_dim, hidden_size, n_layers).to(device)
optimizer = optim.Adam(papagaio_lstm.parameters(), lr)
# The model ends in a sigmoid, so it is paired with BCELoss (which expects probabilities).
loss_fn = nn.BCELoss()

papagaio_lstm.train()
loss_history = []  # mean training loss per epoch
for epoch in tqdm(range(epochs)):
    total_loss = 0.0
    n_batches = 0
    for x, y in dataloader:
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        output, _ = papagaio_lstm(x)
        loss = loss_fn(output, y)
        loss.backward()
        optimizer.step()
        # BCELoss already averages over batch and elements (reduction='mean'),
        # so it must not be divided by batch_size again.
        total_loss += loss.item()
        n_batches += 1

    epoch_loss = total_loss / max(n_batches, 1)
    loss_history.append(epoch_loss)
    # tqdm.write keeps these lines from interleaving with the progress bar.
    tqdm.write(f'Epoch: {epoch+1}')
    tqdm.write(f'Loss: {epoch_loss}')

  1%|          | 1/100 [00:01<02:11,  1.33s/it]

Epoch: 1
Loss: 0.7222133278846741


  2%|▏         | 2/100 [00:01<01:08,  1.43it/s]

Epoch: 2
Loss: 0.30464601516723633


  3%|▎         | 3/100 [00:01<00:47,  2.05it/s]

Epoch: 3
Loss: 0.2105514258146286


  4%|▍         | 4/100 [00:02<00:36,  2.60it/s]

Epoch: 4
Loss: 0.14869393408298492


  5%|▌         | 5/100 [00:02<00:31,  3.06it/s]

Epoch: 5
Loss: 0.10918699204921722


  6%|▌         | 6/100 [00:02<00:27,  3.46it/s]

Epoch: 6
Loss: 0.08438749611377716


  7%|▋         | 7/100 [00:02<00:24,  3.77it/s]

Epoch: 7
Loss: 0.06881191581487656


  8%|▊         | 8/100 [00:02<00:24,  3.78it/s]

Epoch: 8
Loss: 0.058794550597667694


  9%|▉         | 9/100 [00:03<00:22,  4.00it/s]

Epoch: 9
Loss: 0.05212625116109848


 10%|█         | 10/100 [00:03<00:21,  4.15it/s]

Epoch: 10
Loss: 0.04750671237707138


 11%|█         | 11/100 [00:03<00:20,  4.27it/s]

Epoch: 11
Loss: 0.04418046027421951


 12%|█▏        | 12/100 [00:03<00:20,  4.33it/s]

Epoch: 12
Loss: 0.0417046844959259


 13%|█▎        | 13/100 [00:04<00:19,  4.41it/s]

Epoch: 13
Loss: 0.03982410952448845


 14%|█▍        | 14/100 [00:04<00:19,  4.40it/s]

Epoch: 14
Loss: 0.038400840014219284


 15%|█▌        | 15/100 [00:04<00:18,  4.48it/s]

Epoch: 15
Loss: 0.037352055311203


 16%|█▌        | 16/100 [00:04<00:18,  4.46it/s]

Epoch: 16
Loss: 0.036602262407541275


 17%|█▋        | 17/100 [00:04<00:18,  4.49it/s]

Epoch: 17
Loss: 0.036069147288799286


 18%|█▊        | 18/100 [00:05<00:18,  4.53it/s]

Epoch: 18
Loss: 0.03567390888929367


 19%|█▉        | 19/100 [00:05<00:17,  4.51it/s]

Epoch: 19
Loss: 0.0353536494076252


 20%|██        | 20/100 [00:05<00:17,  4.46it/s]

Epoch: 20
Loss: 0.03506641089916229


 21%|██        | 21/100 [00:05<00:17,  4.48it/s]

Epoch: 21
Loss: 0.03479310870170593


 22%|██▏       | 22/100 [00:06<00:17,  4.53it/s]

Epoch: 22
Loss: 0.03453643620014191


 23%|██▎       | 23/100 [00:06<00:16,  4.57it/s]

Epoch: 23
Loss: 0.0343131497502327


 24%|██▍       | 24/100 [00:06<00:16,  4.55it/s]

Epoch: 24
Loss: 0.03414110094308853


 25%|██▌       | 25/100 [00:06<00:16,  4.43it/s]

Epoch: 25
Loss: 0.03402746841311455


 26%|██▌       | 26/100 [00:06<00:16,  4.48it/s]

Epoch: 26
Loss: 0.03396359831094742


 27%|██▋       | 27/100 [00:07<00:16,  4.49it/s]

Epoch: 27
Loss: 0.03392831236124039


 28%|██▊       | 28/100 [00:07<00:15,  4.55it/s]

Epoch: 28
Loss: 0.033897485584020615


 29%|██▉       | 29/100 [00:07<00:15,  4.58it/s]

Epoch: 29
Loss: 0.03385449945926666


 30%|███       | 30/100 [00:07<00:15,  4.42it/s]

Epoch: 30
Loss: 0.03379543498158455


 31%|███       | 31/100 [00:08<00:18,  3.81it/s]

Epoch: 31
Loss: 0.03372655808925629
Epoch: 32

 32%|███▏      | 32/100 [00:08<00:16,  4.03it/s]


Loss: 0.03365735709667206


 33%|███▎      | 33/100 [00:08<00:15,  4.21it/s]

Epoch: 33
Loss: 0.033594489097595215


 34%|███▍      | 34/100 [00:08<00:15,  4.23it/s]

Epoch: 34
Loss: 0.03353990986943245


 35%|███▌      | 35/100 [00:09<00:15,  4.31it/s]

Epoch: 35
Loss: 0.03349295258522034


 36%|███▌      | 36/100 [00:09<00:14,  4.38it/s]

Epoch: 36
Loss: 0.03345353528857231


 37%|███▋      | 37/100 [00:09<00:14,  4.47it/s]

Epoch: 37
Loss: 0.03342341259121895


 38%|███▊      | 38/100 [00:09<00:14,  4.40it/s]

Epoch: 38
Loss: 0.033404890447854996


 39%|███▉      | 39/100 [00:09<00:13,  4.39it/s]

Epoch: 39
Loss: 0.033398017287254333
Epoch: 40

 40%|████      | 40/100 [00:10<00:13,  4.48it/s]


Loss: 0.033398792147636414


 41%|████      | 41/100 [00:10<00:13,  4.50it/s]

Epoch: 41
Loss: 0.0334000326693058


 42%|████▏     | 42/100 [00:10<00:12,  4.54it/s]

Epoch: 42
Loss: 0.03339441493153572


 43%|████▎     | 43/100 [00:10<00:12,  4.54it/s]

Epoch: 43
Loss: 0.03337807208299637


 44%|████▍     | 44/100 [00:11<00:12,  4.57it/s]

Epoch: 44
Loss: 0.03335237503051758


 45%|████▌     | 45/100 [00:11<00:11,  4.60it/s]

Epoch: 45
Loss: 0.0333230122923851


 46%|████▌     | 46/100 [00:11<00:11,  4.51it/s]

Epoch: 46
Loss: 0.033296968787908554


 47%|████▋     | 47/100 [00:11<00:11,  4.46it/s]

Epoch: 47
Loss: 0.03327929973602295


 48%|████▊     | 48/100 [00:11<00:11,  4.48it/s]

Epoch: 48
Loss: 0.03327123820781708


 49%|████▉     | 49/100 [00:12<00:11,  4.54it/s]

Epoch: 49
Loss: 0.03327039256691933


 50%|█████     | 50/100 [00:12<00:11,  4.52it/s]

Epoch: 50
Loss: 0.03327241539955139


 51%|█████     | 51/100 [00:12<00:10,  4.56it/s]

Epoch: 51
Loss: 0.033273134380578995


 52%|█████▏    | 52/100 [00:12<00:10,  4.57it/s]

Epoch: 52
Loss: 0.03327011317014694


 53%|█████▎    | 53/100 [00:13<00:10,  4.58it/s]

Epoch: 53
Loss: 0.03326316922903061


 54%|█████▍    | 54/100 [00:13<00:09,  4.62it/s]

Epoch: 54
Loss: 0.03325380012392998


 55%|█████▌    | 55/100 [00:13<00:09,  4.61it/s]

Epoch: 55
Loss: 0.03324412927031517


 56%|█████▌    | 56/100 [00:13<00:09,  4.52it/s]

Epoch: 56
Loss: 0.033235885202884674


 57%|█████▋    | 57/100 [00:13<00:09,  4.47it/s]

Epoch: 57
Loss: 0.03322984278202057


 58%|█████▊    | 58/100 [00:14<00:09,  4.49it/s]

Epoch: 58
Loss: 0.03322583809494972


 59%|█████▉    | 59/100 [00:14<00:09,  4.51it/s]

Epoch: 59
Loss: 0.03322311118245125


 60%|██████    | 60/100 [00:14<00:08,  4.53it/s]

Epoch: 60
Loss: 0.03322076424956322


 61%|██████    | 61/100 [00:14<00:08,  4.54it/s]

Epoch: 61
Loss: 0.033218156546354294


 62%|██████▏   | 62/100 [00:15<00:08,  4.56it/s]

Epoch: 62
Loss: 0.03321504220366478


 63%|██████▎   | 63/100 [00:15<00:08,  4.58it/s]

Epoch: 63
Loss: 0.03321158513426781


 64%|██████▍   | 64/100 [00:15<00:07,  4.58it/s]

Epoch: 64
Loss: 0.033208176493644714


 65%|██████▌   | 65/100 [00:15<00:07,  4.51it/s]

Epoch: 65
Loss: 0.03320514038205147


 66%|██████▌   | 66/100 [00:15<00:07,  4.29it/s]

Epoch: 66
Loss: 0.03320258483290672


 67%|██████▋   | 67/100 [00:16<00:07,  4.40it/s]

Epoch: 67
Loss: 0.03320032358169556


 68%|██████▊   | 68/100 [00:16<00:07,  4.46it/s]

Epoch: 68
Loss: 0.03319801390171051


 69%|██████▉   | 69/100 [00:16<00:06,  4.44it/s]

Epoch: 69
Loss: 0.03319539874792099


 70%|███████   | 70/100 [00:16<00:06,  4.41it/s]

Epoch: 70
Loss: 0.0331924743950367


 71%|███████   | 71/100 [00:17<00:06,  4.48it/s]

Epoch: 71
Loss: 0.03318953886628151


 72%|███████▏  | 72/100 [00:17<00:06,  4.53it/s]

Epoch: 72
Loss: 0.03318694606423378
Epoch: 73

 73%|███████▎  | 73/100 [00:17<00:05,  4.59it/s]


Loss: 0.033184923231601715


 74%|███████▍  | 74/100 [00:17<00:05,  4.58it/s]

Epoch: 74
Loss: 0.03318339213728905


 75%|███████▌  | 75/100 [00:17<00:05,  4.47it/s]

Epoch: 75
Loss: 0.0331820547580719


 76%|███████▌  | 76/100 [00:18<00:06,  3.80it/s]

Epoch: 76
Loss: 0.03318058326840401


 77%|███████▋  | 77/100 [00:18<00:05,  3.91it/s]

Epoch: 77
Loss: 0.03317878022789955


 78%|███████▊  | 78/100 [00:18<00:05,  3.75it/s]

Epoch: 78
Loss: 0.03317667543888092


 79%|███████▉  | 79/100 [00:19<00:05,  3.87it/s]

Epoch: 79
Loss: 0.033174436539411545


 80%|████████  | 80/100 [00:19<00:05,  3.93it/s]

Epoch: 80
Loss: 0.03317226842045784


 81%|████████  | 81/100 [00:19<00:05,  3.64it/s]

Epoch: 81
Loss: 0.033170316368341446


 82%|████████▏ | 82/100 [00:20<00:05,  3.17it/s]

Epoch: 82
Loss: 0.03316861763596535


 83%|████████▎ | 83/100 [00:20<00:04,  3.48it/s]

Epoch: 83
Loss: 0.03316711634397507


 84%|████████▍ | 84/100 [00:20<00:04,  3.71it/s]

Epoch: 84
Loss: 0.03316571190953255


 85%|████████▌ | 85/100 [00:20<00:03,  3.79it/s]

Epoch: 85
Loss: 0.03316429257392883


 86%|████████▌ | 86/100 [00:21<00:04,  3.41it/s]

Epoch: 86
Loss: 0.033162765204906464


 87%|████████▋ | 87/100 [00:21<00:03,  3.35it/s]

Epoch: 87
Loss: 0.03316112980246544


 88%|████████▊ | 88/100 [00:21<00:03,  3.49it/s]

Epoch: 88
Loss: 0.03315942734479904


 89%|████████▉ | 89/100 [00:21<00:02,  3.69it/s]

Epoch: 89
Loss: 0.03315773978829384


 90%|█████████ | 90/100 [00:22<00:02,  3.85it/s]

Epoch: 90
Loss: 0.033156126737594604


 91%|█████████ | 91/100 [00:22<00:02,  4.03it/s]

Epoch: 91
Loss: 0.03315459191799164


 92%|█████████▏| 92/100 [00:22<00:01,  4.21it/s]

Epoch: 92
Loss: 0.033153120428323746


 93%|█████████▎| 93/100 [00:22<00:01,  4.27it/s]

Epoch: 93
Loss: 0.033151667565107346


 94%|█████████▍| 94/100 [00:23<00:01,  3.72it/s]

Epoch: 94
Loss: 0.033150214701890945


 95%|█████████▌| 95/100 [00:23<00:01,  3.92it/s]

Epoch: 95
Loss: 0.033148761838674545


 96%|█████████▌| 96/100 [00:23<00:00,  4.10it/s]

Epoch: 96
Loss: 0.03314732387661934


 97%|█████████▋| 97/100 [00:23<00:00,  4.18it/s]

Epoch: 97
Loss: 0.03314588963985443


 98%|█████████▊| 98/100 [00:24<00:00,  4.28it/s]

Epoch: 98
Loss: 0.03314445912837982
Epoch: 99

 99%|█████████▉| 99/100 [00:24<00:00,  4.39it/s]


Loss: 0.03314301371574402


100%|██████████| 100/100 [00:24<00:00,  4.08it/s]

Epoch: 100
Loss: 0.03314156457781792





In [8]:
# torch.save(papagaio_lstm.state_dict(), f'models/papagaio_model_{dataset_name}_{epochs}epochs_{lr}lr.pt')

In [9]:
# papagaio_lstm.load_state_dict(torch.load('models/papagaio_model.pt', map_location=device))