In [1]:
import math
import numpy as np
import torch
from importlib import reload
from torch import nn
from tqdm.notebook import tqdm
from deepnote import MusicRepr, Constants
import os

import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

seed_everything(42)

  rank_zero_deprecation(
Global seed set to 42


42

## dataset

In [2]:
const = Constants(unit=4, num_tempo_bins=20, num_velocity_bins=20)

data_config = {
    'data_dir' : '/home/soroosh/data/MIDI/pop909/train/',
    'const' : const,
    'mode' : 'remi',
    'instruments' : ['piano'],
    'max_files' : 2,
    'window_len' : 4096,
    'n_jobs' : 20
}

name = 'remi-rnn-pop909-win512'
print('model name:',name)

model name: remi-rnn-pop909-win512


In [3]:
import src.data
reload(src.data)
from src.data import MidiDataset

dataset = MidiDataset(**data_config)
n = len(dataset)
n

  0%|          | 0/2 [00:00<?, ?it/s]

7639

In [4]:
from torch.utils.data import DataLoader, random_split

t = int(0.1 * n)
td, vd = random_split(dataset, [n-t, t])
tl = DataLoader(dataset=td, batch_size=64, pin_memory=True, shuffle=True, num_workers=4, collate_fn=dataset.fn)
vl = DataLoader(dataset=vd, batch_size=64, pin_memory=True, shuffle=False, num_workers=4, collate_fn=dataset.fn)

In [5]:
b = next(iter(tl))
for k in b:
    print(k, b[k].shape)

X torch.Size([64, 511])
X_len torch.Size([64])
labels torch.Size([64, 511])


## model

In [6]:
import src.config
reload(src.config)
from src.config import make_config

config = make_config(
    const,
    mode='remi',
    model='rnn',
    d_model=256, 
    bidirectional=True,
    dropout=0.1, 
    lr=2e-4,
    tie_emb=False,
    pos_emb=True, 
    n_layer=4, 
)
config['head']['d_model'] = 512
config

{'lr': 0.0002,
 'embedding': {'d_model': 256,
  'dropout': 0.1,
  'max_len': 10000,
  'pos_emb': True,
  'n_vocab': 351},
 'head': {'d_model': 512, 'n_vocab': 351},
 'rnn': {'d_model': 256, 'n_layer': 4, 'dropout': 0.1, 'bidirectional': True},
 'tie_emb': False}

In [7]:
from src.models.remi import RemiRNN

model = RemiRNN(config)
print(model.count_parameters())
model

4607583


RemiRNN(
  (loss_func): CrossEntropyLoss()
  (embedding): RemiEmbedding(
    (emb): Embedding(351, 256)
    (pos_emb): PositionalEncoding()
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (rnn): VanillaRNN(
    (rnn): GRU(256, 256, num_layers=4, batch_first=True, dropout=0.1, bidirectional=True)
  )
  (head): RemiHead(
    (head): Linear(in_features=512, out_features=351, bias=True)
  )
)

In [8]:
logits, loss = model(b['X'], b['X_len'], b['labels'])
loss, logits.shape

(tensor(5.8435, grad_fn=<DivBackward0>), torch.Size([64, 511, 351]))

## train

In [9]:
logger = TensorBoardLogger(save_dir='logs/', name=name)
lr_logger = LearningRateMonitor(logging_interval='step')
checkpoint = ModelCheckpoint(
    dirpath=f'weights/{name}/', 
    filename='{epoch}-{val_loss:.2f}', 
    monitor='train_loss',
    save_top_k=5, 
    period=1
)

trainer = Trainer(
    benchmark=True, 
    gpus=1, 
    accumulate_grad_batches=1,
    logger=logger, 
    max_epochs=1,
    callbacks=[checkpoint, lr_logger]
)

  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [10]:
trainer.fit(model, tl, vl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | loss_func | CrossEntropyLoss | 0     
1 | embedding | RemiEmbedding    | 89.9 K
2 | rnn       | VanillaRNN       | 4.3 M 
3 | head      | RemiHead         | 180 K 
-----------------------------------------------
4.6 M     Trainable params
0         Non-trainable params
4.6 M     Total params
18.430    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Global seed set to 42


Training: 0it [00:00, ?it/s]

In [11]:
trainer.save_checkpoint(f'weights/{name}/last.ckpt')

## generate

In [12]:
import src.models.remi.rnn
reload(src.models.remi.rnn)
from src.models.remi.rnn import RemiRNN

gen_model = RemiRNN.load_from_checkpoint(f'weights/{name}/last.ckpt', config=config)

In [22]:
path = data_config['data_dir']
seq = MusicRepr.from_file(path + os.listdir(path)[0], const=const)
prompt = MusicRepr.concatenate(seq.get_bars()[:10])
len(prompt)

121

In [23]:
gen_remi = model.generate(prompt=prompt, max_len=300, cuda=True, top_p=0.9, temperature=0.8)
gen_remi.shape

  0%|          | 0/300 [00:00<?, ?it/s]

(631,)

In [24]:
tokens = [const.all_tokens[idx] for idx in gen_remi]
print(tokens[:10])

['Bar', 'BeatTempo_115', 'BeatPosition_12', 'BeatTempo_30', 'Bar', 'BeatTempo_115', 'BeatChord_D#_m', 'NoteInstFamily_piano', 'NotePitch_63', 'NoteDuration_8']


In [25]:
gen_seq = MusicRepr.from_string(' '.join(tokens), const=const)
len(gen_seq)

202

In [26]:
gen_seq.to_midi('test.mid')

ticks per beat: 480
max tick: 24840
tempo changes: 43
time sig: 1
key sig: 0
markers: 6
lyrics: False
instruments: 1