In [1]:
# --- standard library ---
import os
from importlib import reload

# --- third-party ---
import numpy as np
import torch
from tqdm.notebook import tqdm
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

# --- project packages ---
from deepnote import MusicRepr, Constants

# Seed the global RNGs (Python, NumPy, torch per the Lightning docs) so the
# random train/val split and weight init are reproducible; the seed is echoed.
seed_everything(42)

  rank_zero_deprecation(
Global seed set to 42


42

## Data

In [2]:
const = Constants(unit=4, num_tempo_bins=20, num_velocity_bins=20)

# Dataset location. Set the MIDI_DATA_DIR environment variable to point at a
# different corpus instead of editing an absolute path here. Alternatives
# previously used with this notebook:
#   /home/soroosh/data/MIDI/e-gmd-v1.0.0/midis_processed/
#   /home/soroosh/data/MIDI/lmd_processed/
DATA_DIR = os.environ.get('MIDI_DATA_DIR', '/home/soroosh/data/MIDI/pop909/train/')

data_config = {
    'data_dir' : DATA_DIR,
    'const' : const,
    'instruments' : ['piano', 'drums'],
    'mode' : 'cp',
    'max_files' : 10,   # NOTE(review): small cap, presumably for quick experiments
    'window_len' : 1024,
    'pad_value' : 0,
    'n_jobs' : 20
}

# Run name, used for the TensorBoard log dir and checkpoint dir. Mode and window
# length are derived from data_config so the name cannot drift from the data
# settings; "small-v-pop" is still hand-maintained (update it if DATA_DIR changes corpus).
name = f"{data_config['mode']}-small-v-pop-win{data_config['window_len']}"
print('model name:',name)

model name: cp-small-v-pop-win1024


In [3]:
import src.data

# Re-import the module so edits to src/data.py are picked up without a kernel restart.
data_module = reload(src.data)
MidiDataset = data_module.MidiDataset

dataset = MidiDataset(**data_config)
n = len(dataset)
n

  0%|          | 0/10 [00:00<?, ?it/s]

15734

In [4]:
from torch.utils.data import DataLoader, random_split

# Hold out 10% of the samples for validation (the split is seeded by seed_everything above).
t = int(0.1 * n)
td, vd = random_split(dataset, [n - t, t])

# Options shared by both loaders; only batch size and shuffling differ.
common_loader_kwargs = dict(pin_memory=False, num_workers=4, collate_fn=dataset.fn)
tl = DataLoader(dataset=td, batch_size=16, shuffle=True, **common_loader_kwargs)
vl = DataLoader(dataset=vd, batch_size=32, shuffle=False, **common_loader_kwargs)

In [5]:
b = next(iter(tl))
# Sanity-check one training batch: print each entry's name, shape and device.
for key, tensor in b.items():
    print(key, tensor.shape, tensor.device)

X torch.Size([16, 1023, 8]) cpu
X_len torch.Size([16]) cpu
labels torch.Size([16, 1023, 8]) cpu


## Model

In [6]:
import src.config
reload(src.config)
from src.config import make_config

# Model / optimizer hyper-parameters.
# `mode` is read from data_config so the model's token representation always
# matches what the dataset produces. Setting data_config['mode'] = 'remi' gives
# the REMI variant of this same config; the model class in the next cell must
# then be changed to match.
config = make_config(
    const,
    mode=data_config['mode'],
    model='transformer',
    d_model=256,
    max_len=10000,
    dropout=0.1,
    lr=2e-4,
    tie_emb=False,
    pos_emb='relative',
    n_layer=4,
    n_head=8,
    d_inner=256,
    activation='gelu',
)
config

{'lr': 0.0002,
 'embedding': {'d_model': 256,
  'dropout': 0.1,
  'max_len': 10000,
  'positional_embedding': 'relative',
  'attributes': ['ttype',
   'barbeat',
   'tempo',
   'chord',
   'inst_family',
   'pitch',
   'duration',
   'velocity'],
  'n_tokens': {'ttype': 2,
   'barbeat': 16,
   'tempo': 21,
   'chord': 133,
   'inst_family': 17,
   'pitch': 128,
   'duration': 16,
   'velocity': 20},
  'emb_sizes': {'ttype': 8,
   'barbeat': 32,
   'tempo': 32,
   'chord': 128,
   'inst_family': 32,
   'pitch': 128,
   'duration': 32,
   'velocity': 32}},
 'head': {'d_model': 256,
  'attributes': ['ttype',
   'barbeat',
   'tempo',
   'chord',
   'inst_family',
   'pitch',
   'duration',
   'velocity'],
  'n_tokens': {'ttype': 2,
   'barbeat': 16,
   'tempo': 21,
   'chord': 133,
   'inst_family': 17,
   'pitch': 128,
   'duration': 16,
   'velocity': 20},
  'emb_sizes': {'ttype': 8,
   'barbeat': 32,
   'tempo': 32,
   'chord': 128,
   'inst_family': 32,
   'pitch': 128,
   'duration':

In [7]:
from src.models.cp import CPSimpleTransformer

# Train from scratch. To resume from a saved run instead, use e.g.:
#   model = CPSimpleTransformer.load_from_checkpoint(f'weights/{name}/last.ckpt', config=config)
model = CPSimpleTransformer(config)
model.count_parameters()

1818929

In [8]:
logger = TensorBoardLogger(save_dir='logs/', name=name)
lr_logger = LearningRateMonitor(logging_interval='step')

# Keep the 5 best checkpoints ranked by validation loss -- the same metric the
# filename reports (ranking by train_loss would favour overfit epochs) -- plus a
# rolling `last.ckpt`, which the generation section loads, so it exists even if
# training is interrupted. `period=1` is dropped: saving every epoch is already
# the default and the argument is deprecated in recent Lightning releases.
# NOTE(review): assumes the LightningModule logs `val_loss` (the filename
# template already relies on it) -- confirm.
checkpoint = ModelCheckpoint(
    dirpath=f'weights/{name}/',
    filename='{epoch}-{val_loss:.2f}',
    monitor='val_loss',
    mode='min',
    save_top_k=5,
    save_last=True,
)

trainer = Trainer(
    benchmark=True,
    gpus=1,
    accumulate_grad_batches=1,
    logger=logger,
    max_epochs=100,
    callbacks=[checkpoint, lr_logger]
)

  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [None]:
# Train `model` on the training split (tl), validating on the held-out 10% split (vl).
trainer.fit(model, tl, vl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type               | Params
---------------------------------------------------
0 | loss_func   | CrossEntropyLoss   | 0     
1 | embedding   | CPEmbedding        | 145 K 
2 | transformer | VanillaTransformer | 1.6 M 
3 | head        | CPSimpleHead       | 90.7 K
---------------------------------------------------
1.8 M     Trainable params
0         Non-trainable params
1.8 M     Total params
7.276     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Global seed set to 42


Training: 0it [00:00, ?it/s]

In [None]:
# Persist the final weights at the path the generation section loads from.
last_ckpt_path = f'weights/{name}/last.ckpt'
trainer.save_checkpoint(last_ckpt_path)

## Generate

In [None]:
from src.models.cp import CPSimpleTransformer

# Reload the checkpoint with the same class that was trained (CPSimpleTransformer,
# whose summary shows CPEmbedding / VanillaTransformer / CPSimpleHead). Loading
# it into a different class such as CPTransformer risks missing/unexpected
# state_dict keys or a silently different head.
# NOTE(review): assumes CPSimpleTransformer exposes generate() -- confirm.
# .eval() disables dropout for sampling: nn.Module instances start in training
# mode, and generate() is not shown to switch modes itself.
gen_model = CPSimpleTransformer.load_from_checkpoint(f"weights/{name}/last.ckpt", config=config).eval()

In [None]:
def sample_prompt(data_dir, const, n_bars=10, instruments=('piano', 'drums'), seed=None):
    """Build a short prompt from the opening bars of a randomly chosen MIDI file.

    Args:
        data_dir: directory containing the MIDI files.
        const: ``Constants`` instance used to parse the file.
        n_bars: number of leading bars to keep.
        instruments: track names to keep; other tracks are dropped.
        seed: optional seed for the file choice. When None, the global
            ``random`` module is used (seeded by ``seed_everything``).

    Returns:
        ``MusicRepr`` made of the first ``n_bars`` bars of the kept tracks.
        NOTE(review): check which format ``gen_model.generate(prompt=...)``
        expects before passing the result to it.

    Raises:
        FileNotFoundError: if ``data_dir`` has no entries.
    """
    import random

    # Choose among the files that actually exist. The earlier snippet indexed
    # with random.randint(0, 21000), which raises IndexError for smaller dirs.
    # Sorting also makes the choice independent of listdir's arbitrary order.
    files = sorted(os.listdir(data_dir))
    if not files:
        raise FileNotFoundError(f'no files found in {data_dir!r}')
    rng = random if seed is None else random.Random(seed)
    fname = rng.choice(files)

    seq = MusicRepr.from_file(os.path.join(data_dir, fname), const=const)
    kept_tracks = {k: v for k, v in seq.separate_tracks().items() if k in instruments}
    seq = MusicRepr.merge_tracks(kept_tracks)
    return MusicRepr.concatenate(seq.get_bars()[:n_bars])

In [None]:
# Per-attribute sampling settings passed to generate(): a `p_<attr>` and a
# `t_<attr>` entry for each CP attribute. NOTE(review): presumably top-p and
# temperature respectively, inferred from the key names -- confirm in generate().
attr_names = ['ttype', 'barbeat', 'tempo', 'chord', 'inst_family', 'pitch', 'duration', 'velocity']
gen_conf = {}
for attr in attr_names:
    gen_conf[f'p_{attr}'] = 0.8
    gen_conf[f't_{attr}'] = 0.8
gen_conf['p_ttype'] = 1.0   # the only attribute whose p_ value differs

gen_cp = gen_model.generate(prompt=None, max_len=500, window=500, cuda=True, gen_conf=gen_conf)
gen_cp.shape

In [None]:
# Cast the generated CP tokens to integers before decoding them into a MusicRepr.
cp_tokens = gen_cp.astype(int)
gen_seq = MusicRepr.from_cp(cp_tokens, const=const)

In [None]:
# gen_remi = gen_model.generate(prompt=None, max_len=1000, window=500, cuda=True, top_p=.9, temperature=.7)
# print(gen_remi.shape)

# tokens = [const.all_tokens[idx] for idx in gen_remi]
# print(tokens[:10])
# gen_seq = MusicRepr.from_string(' '.join(tokens), const=const)
# len(gen_seq), gen_seq.get_bar_count()

In [None]:
# Write the generated sequence to a MIDI file in the working directory.
midi_out = 'cp-v-drums.mid'
gen_seq.to_midi(midi_out)

In [None]:
# Length of the generated sequence and how many bars it spans.
seq_len = len(gen_seq)
bar_count = gen_seq.get_bar_count()
(seq_len, bar_count)

In [None]:
import matplotlib.pyplot as plt

# Show the piano roll of one generated track.
TRACK = 'drums'
rolls = gen_seq.to_pianoroll(add_tempo_chord=False)
# Sampling may produce no events for this track; fail with a readable message
# instead of a bare KeyError.
if TRACK not in rolls:
    raise KeyError(f'{TRACK!r} not in generated tracks: {list(rolls)}')

fig, ax = plt.subplots(figsize=(20, 5))
# aspect='auto' stretches the image to the figure; the default square pixels
# squash a long roll into a thin strip.
ax.imshow(rolls[TRACK], aspect='auto')
ax.set_title(f'{name}: generated {TRACK} piano roll')
plt.show()