In [1]:
import numpy as np
import torch
import os
from tqdm.notebook import tqdm

from deepnote import MusicRepr, Constants
from importlib import reload

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

# Fix the global random seed before any data or model code runs
# (the cell output echoes "Global seed set to 42").
seed_everything(42)

  rank_zero_deprecation(
Global seed set to 42


42

## Data

In [2]:
# Shared `Constants` instance; it is reused by the dataset, the model config
# and the MusicRepr (de)serialisation calls later in the notebook.
const = Constants(unit=4, num_tempo_bins=20, num_velocity_bins=20)

# Keyword arguments forwarded verbatim to `MidiDataset(**data_config)`.
# Alternative `data_dir` values that were used previously:
#     '/home/soroosh/data/MIDI/pop909/train/'
#     '/home/soroosh/data/MIDI/e-gmd-v1.0.0/midis_processed/'
data_config = dict(
    data_dir='/home/soroosh/data/MIDI/lmd_processed/',
    const=const,
    instruments=['piano', 'drums'],
    mode='remi',
    max_files=10,
    window_len=1024,
    pad_value=0,
    n_jobs=20,
)

# Run name; it selects the log directory and the checkpoint directory below.
name = 'remi-small-v-lmd-win1024'
print('model name:', name)

model name: remi-small-v-lmd-win1024


In [39]:
# Reload the data module so that edits to src/data are picked up without
# restarting the kernel, then take the dataset class from the fresh module.
import src.data
MidiDataset = reload(src.data).MidiDataset

dataset = MidiDataset(**data_config)

# Number of items in the dataset; also displayed as the cell result.
n = len(dataset)
n

  0%|          | 0/10 [00:00<?, ?it/s]

85235

In [40]:
from torch.utils.data import DataLoader, random_split

# Hold out 10% of the dataset items for validation.
t = int(0.1 * n)

# Split with a dedicated, explicitly seeded generator. Without it the split
# draws from the global RNG, so the train/validation partition depends on how
# many random numbers earlier cells consumed -- in a notebook run out of order
# (or resumed from a checkpoint) validation items could leak into training.
split_rng = torch.Generator().manual_seed(42)
td, vd = random_split(dataset, [n - t, t], generator=split_rng)

# Settings common to both loaders; batches are assembled by the dataset's own
# collate function `dataset.fn`.
loader_kwargs = dict(pin_memory=False, num_workers=4, collate_fn=dataset.fn)

# Training loader: shuffled every epoch.
tl = DataLoader(dataset=td, batch_size=16, shuffle=True, **loader_kwargs)
# Validation loader: fixed order, larger batch.
vl = DataLoader(dataset=vd, batch_size=32, shuffle=False, **loader_kwargs)

In [5]:
# Draw a single batch from the training loader and print the shape of each
# entry as a quick sanity check.
b = next(iter(tl))
for k, v in b.items():
    print(k, v.shape)

X torch.Size([18, 1023])
X_len torch.Size([18])
labels torch.Size([18, 1023])


## Model

In [3]:
# Reload the config module so local edits take effect without a kernel restart.
import src.config
reload(src.config)
from src.config import make_config

from transformers import GPT2Config, TransfoXLConfig

# Hyper-parameters passed to `make_config`, gathered in one place so they are
# easy to scan and tweak.
hparams = dict(
    mode='remi',
    model='transformer',
    d_model=256,
    max_len=10000,
    dropout=0.1,
    lr=2e-4,
    tie_emb=False,
    pos_emb='relative',
    n_layer=4,
    n_head=8,
    d_inner=256,
    activation='gelu',
)
config = make_config(const, **hparams)

# Alternative hand-written configs for the HuggingFace-backed models, kept for
# reference (the two 'transformer' entries are mutually exclusive options):
# config = {
#     'lr' : 1e-4,
#     'transformer': TransfoXLConfig(
#         vocab_size=len(const.all_tokens) + 1,
#         cutoffs=[],
#         d_model=256,
#         d_embed=256,
#         d_head=32,
#         n_head=8,
#         d_inner=256,
#         n_layer=4,
#         dropout=0.1,
#         clamp_len=512,
#         pad_token_id=len(const.all_tokens),
#         eos_token_id=1,
#         bos_token_id=0
#     )
#     'transformer': GPT2Config(
#         vocab_size=len(const.all_tokens),
#         n_positions=1024,
#         n_ctx=1024,
#         n_embd=256,
#         n_layer=4,
#         n_head=8,
#         n_inner=512,
#         pad_token_id=len(const.all_tokens),
#         eos_token_id=1,
#         bos_token_id=0
#     )
# }

# Display the resulting config as the cell output.
config

{'lr': 0.0002,
 'embedding': {'d_model': 256,
  'dropout': 0.1,
  'max_len': 10000,
  'positional_embedding': 'relative',
  'n_vocab': 351},
 'head': {'d_model': 256, 'n_vocab': 351},
 'transformer': {'d_model': 256,
  'n_layer': 4,
  'n_head': 8,
  'd_inner': 256,
  'dropout': 0.1,
  'activation': 'gelu'},
 'tie_emb': False}

In [43]:
from src.models.remi import (
    RemiLinearTransformer,
    RemiHFTransformer,
    RemiTransformer,
)

# Build a fresh model from `config`.
model = RemiTransformer(config)
# To resume from the last saved checkpoint instead, use:
# model = RemiTransformer.load_from_checkpoint(f'weights/{name}/last.ckpt', config=config)

# Show the parameter count as the cell output.
model.count_parameters()

1763167

In [44]:
# TensorBoard logs are written under logs/ using the run name.
logger = TensorBoardLogger(save_dir='logs/', name=name)
# Record the learning rate at every optimisation step.
lr_logger = LearningRateMonitor(logging_interval='step')

# Keep the 5 best checkpoints in weights/<name>/.
# - The ranking metric is `val_loss`, the same metric embedded in the
#   filename. The previous setting ranked on `train_loss`, which rewards
#   over-fitting and made the val_loss in the filename unrelated to why a
#   checkpoint was kept.
#   NOTE(review): assumes the LightningModule logs a metric named `val_loss`
#   (the filename template already relied on it) -- confirm.
# - `period=1` was dropped: the argument is deprecated in this Lightning
#   version, and checkpointing after every validation epoch is the default.
checkpoint = ModelCheckpoint(
    dirpath=f'weights/{name}/',
    filename='{epoch}-{val_loss:.2f}',
    monitor='val_loss',
    mode='min',
    save_top_k=5,
)

# Single-GPU trainer, 5 epochs, no gradient accumulation.
trainer = Trainer(
    benchmark=True,
    gpus=1,
    accumulate_grad_batches=1,
    logger=logger,
    max_epochs=5,
    callbacks=[checkpoint, lr_logger],
)

  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [45]:
# Run training: `tl` is the shuffled training loader, `vl` the validation loader.
trainer.fit(model, tl, vl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type               | Params
---------------------------------------------------
0 | loss_func   | CrossEntropyLoss   | 0     
1 | embedding   | RemiEmbedding      | 89.9 K
2 | transformer | VanillaTransformer | 1.6 M 
3 | head        | RemiHead           | 90.2 K
---------------------------------------------------
1.8 M     Trainable params
0         Non-trainable params
1.8 M     Total params
7.053     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Global seed set to 42


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [46]:
# Save the trainer's current state to a fixed path; the generation section
# below reloads the model from this file.
trainer.save_checkpoint(f'weights/{name}/last.ckpt')

## Generate

In [4]:
from src.models.remi import RemiHFTransformer, RemiTransformer

# Load the model used for generation from the checkpoint saved at the end of
# training for this run name.
last_ckpt_path = f"weights/{name}/last.ckpt"
gen_model = RemiTransformer.load_from_checkpoint(last_ckpt_path, config=config)

In [48]:
# gen_conf = {
#     'p_ttype' : 1.,
#     't_ttype' : 1.,
#     'p_barbeat' : .9,
#     't_barbeat' : .7,
#     'p_tempo' : 1.,
#     't_tempo' : .7,
#     'p_chord' : 1.,
#     't_chord' : .9,
#     'p_pitch' : .8,
#     't_pitch' : .7,
#     'p_duration' : .8,
#     't_duration' : .7,
#     'p_velocity' : 1.,
#     't_velocity' : 1.,
# }

In [51]:
import random

path = data_config['data_dir']

# os.listdir() returns entries in arbitrary order; sorting makes a printed
# `idx` refer to the same file on every run.
files = sorted(os.listdir(path))

# Draw the index from the actual number of files. The previous hard-coded
# upper bound (21000) raised IndexError on smaller directories and could never
# select files beyond it on larger ones.
idx = random.randrange(len(files))
print('idx: ', idx)
seq = MusicRepr.from_file(os.path.join(path, files[idx]), const=const)

# Keep only the instruments the dataset was configured with, instead of
# repeating the list literally here.
wanted_instruments = data_config['instruments']
tracks = seq.separate_tracks()
tracks = {k: v for k, v in tracks.items() if k in wanted_instruments}
seq = MusicRepr.merge_tracks(tracks)
print(seq.get_instruments())

# Use the first 10 bars of the filtered piece as the generation prompt, and
# show its length both as a MusicRepr and as a REMI sequence.
prompt = MusicRepr.concatenate(seq.get_bars()[:10])
len(prompt), len(prompt.to_remi())

idx:  8024
['drums', 'piano']


(158, 479)

In [56]:
# Continue the prompt with the trained model, sampling on the GPU with
# nucleus (top-p) sampling and a temperature below 1.
gen_remi = gen_model.generate(
    prompt=prompt,
    max_len=1000,
    window=500,
    cuda=True,
    top_p=.95,
    temperature=.8,
)

# Display the shape of the generated sequence.
gen_remi.shape

  0%|          | 0/1000 [00:00<?, ?it/s]

(1479,)

In [57]:
# Map every generated index back to its token string.
tokens = [const.all_tokens[token_id] for token_id in gen_remi]
print(tokens[:10])

# Rebuild a MusicRepr from the space-separated token string.
remi_text = ' '.join(tokens)
gen_seq = MusicRepr.from_string(remi_text, const=const)

# Display the sequence length and the number of bars.
len(gen_seq), gen_seq.get_bar_count()

['Bar', 'BeatTempo_157', 'Bar', 'BeatChord_G_m7', 'BeatPosition_8', 'BeatChord_C_M', 'Bar', 'Bar', 'BeatPosition_14', 'NoteInstFamily_drums']


(477, 17)

In [54]:
# gen_cp = np.concatenate(
#     [
#         gen_cp[:,:4], 
#         np.ones(shape=(gen_cp.shape[0],1))*const.instruments.index('piano'), 
#         gen_cp[:, 4:]
#     ], 
#     axis=1
# )
# gen_seq = MusicRepr.from_cp(gen_cp.astype(int), const=const)

In [58]:
# Write the generated sequence to a MIDI file in the current working directory.
gen_seq.to_midi('v-gen-multi-cont.mid')

ticks per beat: 384
max tick: 25920
tempo changes: 1
time sig: 1
key sig: 0
markers: 16
lyrics: False
instruments: 2