In [1]:
import numpy as np
import torch
from torch import nn
import os
from tqdm.notebook import tqdm

from deepnote import MusicRepr, Constants
from importlib import reload

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

# Fix the global random seed via Lightning's seed_everything so repeated runs are comparable.
seed_everything(42)

  rank_zero_deprecation(
Global seed set to 42


42

## data

In [3]:
# Tokenisation constants; the same object is handed to the dataset and to MusicRepr.
const = Constants(unit=4, num_tempo_bins=20, num_velocity_bins=20)

# Keyword arguments unpacked into MidiDataset(**data_config).
# Alternative corpora previously used for 'data_dir':
#   '/home/soroosh/data/MIDI/pop909/train/'
#   '/home/soroosh/data/MIDI/e-gmd-v1.0.0/midis_processed/'
data_config = dict(
    data_dir='/home/soroosh/data/MIDI/lmd_processed/',
    const=const,
    src_instruments=['piano', 'drums', 'guitar'],
    trg_instruments=['piano', 'drums', 'guitar'],
    max_files=1000,
    window_len=2,
    max_len=2048,
    pad_value=0,
    n_jobs=20,
)

# Run identifier, reused for the TensorBoard run name and the weights/ sub-folder.
name = 'encdec-lmd-win2-piano-guitar-drums'
print('model name:', name)

model name: encdec-lmd-win2-piano-guitar-drums


In [4]:
# Reload src.data before importing from it so edits to the module are picked up
# without restarting the kernel.
import src.data
reload(src.data)
from src.data import MidiDataset, get_dataloaders

# Build the dataset from the settings collected in data_config.
dataset = MidiDataset(**data_config)
n = len(dataset)
# Show the sample count next to len(dataset.lens).
# NOTE(review): dataset.lens presumably holds one entry per successfully loaded file — confirm in src.data.
n, len(dataset.lens)

  0%|          | 0/1000 [00:00<?, ?it/s]

(84138, 859)

In [5]:
# Peek at one dataset item: for every instrument, list each field and its length.
sample = dataset[100]
for inst, fields in sample.items():
    print(inst)
    for key, values in fields.items():
        print('  ', key, len(values))

drums
   src 92
   trg 76
guitar
   src 76
   trg 92


In [6]:
# Build the two loaders that are later passed to trainer.fit(model, tl, vl).
tl, vl = get_dataloaders(dataset, batch_size=2, n_jobs=2)

In [20]:
# Pull the first batch from vl and print the tensor shape of every field per instrument.
b = next(iter(vl))
for inst, fields in b.items():
    print(inst)
    for key, tensor in fields.items():
        print('   ', key, tensor.shape)

drums
    src torch.Size([2, 233])
    trg torch.Size([2, 151])
    src_len torch.Size([2])
    trg_len torch.Size([2])
    labels torch.Size([2, 151])
guitar
    src torch.Size([2, 197])
    trg torch.Size([2, 189])
    src_len torch.Size([2])
    trg_len torch.Size([2])
    labels torch.Size([2, 189])
piano
    src torch.Size([1, 326])
    trg torch.Size([1, 60])
    src_len torch.Size([1])
    trg_len torch.Size([1])
    labels torch.Size([1, 60])


## model

In [17]:
# Reload the model module before importing from it so source edits take effect
# without restarting the kernel.
import src.models.enc_dec
reload(src.models.enc_dec)
from src.models.enc_dec import EncoderDecoderPerformer

In [18]:
# Hyperparameters shared by several sub-configs below.
d_model = 256
dropout = 0.1
n_vocab = len(const.all_tokens)

# Nested configuration consumed by EncoderDecoderPerformer.
# The encoder and decoder stacks use the same settings.
config = dict(
    lr=1e-4,
    instruments=['piano', 'drums', 'guitar'],
    embedding=dict(
        d_model=d_model,
        positional_embedding='relative',
        n_vocab=n_vocab,
        dropout=dropout,
        max_len=10000,
    ),
    encoder=dict(
        d_model=d_model,
        n_head=8,
        d_inner=512,
        dropout=dropout,
        n_layer=4,
    ),
    decoder=dict(
        d_model=d_model,
        n_head=8,
        d_inner=512,
        dropout=dropout,
        n_layer=4,
    ),
    head=dict(
        d_model=d_model,
        n_vocab=n_vocab,
    ),
)

model = EncoderDecoderPerformer(config)
# model = BasePerformer.load_from_checkpoint(f'weights/{name}/last.ckpt', config=config)
# Display the parameter count as the cell output.
model.count_parameters()

5632029

In [21]:
# Smoke test before training: run one forward pass on the piano part of batch b
# and display the resulting loss.
logits, loss = model('piano', **b['piano'])
loss

tensor(5.8758, grad_fn=<DivBackward0>)

## train

In [22]:
# TensorBoard logs are written under logs/<name>/.
logger = TensorBoardLogger(save_dir='logs/', name=name)
# Record the learning rate at every optimisation step.
lr_logger = LearningRateMonitor(logging_interval='step')

# Keep the 5 best checkpoints according to the monitored metric.
# The deprecated `period=1` argument has been dropped: saving once per epoch is
# already the default, and newer Lightning releases no longer accept `period`.
# NOTE(review): the filename template embeds val_loss while the ranking metric is
# train_loss — confirm this mismatch is intentional.
checkpoint = ModelCheckpoint(
    dirpath=f'weights/{name}/',
    filename='{epoch}-{val_loss:.2f}',
    monitor='train_loss',
    save_top_k=5,
)

# Single-GPU training; gradients are accumulated over 8 batches per optimiser step.
trainer = Trainer(
    benchmark=True,
    gpus=1,
    accumulate_grad_batches=8,
    logger=logger,
    max_epochs=30,
    callbacks=[checkpoint, lr_logger],
)

  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [None]:
# Run training with tl as the training loader and vl as the validation loader.
trainer.fit(model, tl, vl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params
-------------------------------------------------
0 | criterion | CrossEntropyLoss   | 0     
1 | embedding | RemiEmbedding      | 89.9 K
2 | encoder   | TransformerEncoder | 2.1 M 
3 | decoder   | TransformerDecoder | 3.2 M 
4 | heads     | ModuleDict         | 270 K 
-------------------------------------------------
5.6 M     Trainable params
0         Non-trainable params
5.6 M     Total params
22.528    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Global seed set to 42


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [24]:
# Save the current trainer/model state as last.ckpt in the run's weights folder;
# the generate section loads this file.
trainer.save_checkpoint(f'weights/{name}/last.ckpt')

## generate

In [25]:
# Reload the model module so the latest generation code is used.
import src.models.enc_dec
reload(src.models.enc_dec)
from src.models.enc_dec import EncoderDecoderPerformer

# Restore the weights from last.ckpt into a separate model instance used for generation.
gen_model = EncoderDecoderPerformer.load_from_checkpoint(f"weights/{name}/last.ckpt", config=config)

In [37]:
import random

# Pick a random MIDI file from the dataset directory to use as generation input.
path = data_config['data_dir']
files = os.listdir(path)
# randrange excludes the upper bound. The previous randint(0, len(files)) could
# return len(files) and make files[idx] raise IndexError.
idx = random.randrange(len(files))
file = files[idx]
print('idx:', idx, ' file:', file)
# os.path.join works whether or not data_dir ends with a path separator.
seq = MusicRepr.from_file(os.path.join(path, file), const=const).keep_instruments(['piano','drums', 'guitar'])
# Display which of the three instruments are present in the chosen file.
seq.get_instruments()

idx: 8024  file: 7669773e5ac1406e5f84a72d361dcc31.mid


['drums', 'guitar', 'piano']

In [45]:
# Instrument whose track will be generated.
trg_inst = 'drums'

# Build the prompt from the first 20 bars with the target instrument removed.
leading_bars = seq.get_bars()[:20]
prompt = MusicRepr.concatenate(leading_bars).remove_instruments([trg_inst])

# Display the remaining instruments, the prompt length and its REMI token count.
(prompt.get_instruments(), len(prompt), len(prompt.to_remi()))

(['guitar', 'piano'], 781, 2714)

In [46]:
# Generate the target-instrument track conditioned on the prompt.
# NOTE(review): top_p and t look like the nucleus-sampling threshold and the sampling
# temperature — confirm against EncoderDecoderPerformer.generate.
res = gen_model.generate(trg_inst, seq=prompt, window=10, top_p=.9, t=.8)
print(len(res))

# Convert the generated indices back into a MusicRepr and display its length.
gen_seq = MusicRepr.from_indices(res, const=const)
len(gen_seq)

  0%|          | 0/20 [00:00<?, ?it/s]

1228


409

In [47]:
# Combine the prompt tracks with the generated track for the target instrument.
tracks = prompt.separate_tracks()
tracks[trg_inst] = gen_seq
final_seq = MusicRepr.merge_tracks(tracks)

# Name the output folder after the source file without its extension.
# splitext handles any extension length, unlike the previous file[:-4] slice,
# which was only right for 4-character extensions such as '.mid'.
stem = os.path.splitext(file)[0]
save_path = f'assets/EncDec/{stem}/'
os.makedirs(save_path, exist_ok=True)

# Write the merged result, the generated track alone, and the prompt alone.
final_seq.to_midi(os.path.join(save_path, f'{trg_inst}_merge.mid'))
gen_seq.to_midi(os.path.join(save_path, f'{trg_inst}_gen.mid'))
prompt.to_midi(os.path.join(save_path, f'{trg_inst}_prompt.mid'))

ticks per beat: 384
max tick: 30720
tempo changes: 1
time sig: 1
key sig: 0
markers: 21
lyrics: False
instruments: 2