In [1]:
import numpy as np
import torch
from torch import nn
import os
from tqdm.notebook import tqdm
import pickle
from deepnote import MusicRepr, Constants
from importlib import reload

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

# Fix the global random seed up front so repeated runs start from the same state.
SEED = 42
seed_everything(SEED)

  rank_zero_deprecation(
Global seed set to 42


42

## data

In [2]:
# Tokenisation settings; the same object is handed to the dataset (below) and
# is reused wherever a MusicRepr is built or decoded.
const = Constants(unit=4, num_tempo_bins=20, num_velocity_bins=20)

# Load the pre-computed file list. The context manager closes the handle
# deterministically (a bare `open(...)` inside `pickle.load` never closes it).
# NOTE(review): unpickling can execute arbitrary code -- only load a
# 'files.pkl' that was produced locally by a trusted script.
with open('files.pkl', 'rb') as files_fh:
    files = pickle.load(files_fh)

# Keyword arguments for MultiTrackDataset (unpacked with ** when it is built).
data_config = {
    # Alternative datasets previously used with this notebook:
#     'data_dir' : '/home/soroosh/data/MIDI/pop909/train/',
#     'data_dir' : '/home/soroosh/data/MIDI/e-gmd-v1.0.0/midis_processed/',
    'data_dir' : '/home/soroosh/data/MIDI/lmd_processed/',
    'files' : files,
    'const' : const,
    'instruments' : ['piano', 'drums', 'guitar', 'bass'],
#     'src_instruments' : ['piano', 'drums', 'guitar'],
#     'trg_instruments' : ['piano', 'drums', 'guitar'],
    'max_files' : 1000,
    'window_len' : 2,
    'max_len' : 1536,
    'pad_value' : 0,
    'n_jobs' : 20
}

# Run name; reused for the log directory and the checkpoint directory.
name = 'mix-lmd-win2-piano-guitar-drums-bass'
print('model name:',name)

model name: mix-lmd-win2-piano-guitar-drums-bass


In [3]:
import src.data.multi
import src.data
from src.data import MultiTrackDataset, get_dataloaders

# Pick up on-disk edits without restarting the kernel: the submodule is
# reloaded first, then the package that re-exports from it.
for stale_module in (src.data.multi, src.data):
    reload(stale_module)

# Re-bind the names so they point at the freshly reloaded definitions.
from src.data import MultiTrackDataset, get_dataloaders

dataset = MultiTrackDataset(**data_config)
n = len(dataset)
# Displayed as the cell result: (number of samples, number of entries in dataset.lens).
(n, len(dataset.lens))

  0%|          | 0/1000 [00:00<?, ?it/s]

(47770, 1000)

In [4]:
# Inspect one item: print each key of the sample with the length of its value.
sample = dataset[100]
for inst in sample:
    inst_tokens = sample[inst]
    print(inst, len(inst_tokens))

piano 13
drums 277
guitar 19
bass 46


In [5]:
# Build the two dataloaders; `tl` and `vl` are later passed to trainer.fit
# in that order.
loaders = get_dataloaders(
    dataset,
    batch_size=2,
    n_jobs=2,
)
tl, vl = loaders

In [6]:
# Pull a single batch from `vl` and list the shape of every tensor in it,
# grouped by instrument.
vl_iterator = iter(vl)
b = next(vl_iterator)
for inst in b:
    print(inst)
    inst_fields = b[inst]
    for k in inst_fields:
        print('   ', k, inst_fields[k].shape)

piano
    X torch.Size([2, 161])
    X_len torch.Size([2])
    labels torch.Size([2, 161])
drums
    X torch.Size([2, 168])
    X_len torch.Size([2])
    labels torch.Size([2, 168])
guitar
    X torch.Size([2, 328])
    X_len torch.Size([2])
    labels torch.Size([2, 328])
bass
    X torch.Size([2, 75])
    X_len torch.Size([2])
    labels torch.Size([2, 75])


## model

In [7]:
import src.models.mix_enc_dec

# importlib.reload returns the (re-executed) module, so the class is taken
# straight from the reloaded module object.
mix_enc_dec = reload(src.models.mix_enc_dec)
EncoderMixDecoderPerformer = mix_enc_dec.EncoderMixDecoderPerformer

In [8]:
# Values shared by several sub-configs below.
d_model = 256
n_vocab = len(const.all_tokens)
dropout = 0.1

# The encoder and decoder stacks use identical settings; each one receives its
# own copy so the two dicts stay independent objects.
transformer_stack = dict(
    d_model=d_model,
    n_head=8,
    d_inner=512,
    dropout=dropout,
    n_layer=4,
)

config = dict(
    lr=1e-4,
    instruments=['piano', 'drums', 'guitar', 'bass'],
    embedding=dict(
        d_model=d_model,
        positional_embedding='relative',
        n_vocab=n_vocab,
        dropout=dropout,
        max_len=10000,
    ),
    encoder=dict(transformer_stack),
    decoder=dict(transformer_stack),
    head=dict(
        d_model=d_model,
        n_vocab=n_vocab,
    ),
)

model = EncoderMixDecoderPerformer(config)
# Alternative kept from the original notebook (resume from the last checkpoint):
# model = BasePerformer.load_from_checkpoint(f'weights/{name}/last.ckpt', config=config)

# Cell result: the value returned by the model's own parameter counter.
model.count_parameters()

5722236

In [9]:
# Smoke test: one forward pass for the 'piano' target on the batch `b`.
# The call returns a (logits, loss) pair; the loss is shown as the cell result.
forward_out = model('piano', b)
logits, loss = forward_out
loss

tensor(5.9022, grad_fn=<DivBackward0>)

## train

In [10]:
# TensorBoard run directory is logs/<name>/.
logger = TensorBoardLogger(save_dir='logs/', name=name)
# Record the learning rate at every optimisation step.
lr_logger = LearningRateMonitor(logging_interval='step')

# Keep the five best checkpoints under weights/<name>/.
# `period=1` was removed: it is a deprecated ModelCheckpoint argument (this
# cell emitted a deprecation warning) and 1 is already the default interval,
# so checkpointing frequency is unchanged.
# NOTE(review): the filename template embeds `val_loss` while the ranking
# metric is `train_loss` -- confirm which metric is meant to select the top-k.
checkpoint = ModelCheckpoint(
    dirpath=f'weights/{name}/',
    filename='{epoch}-{val_loss:.2f}',
    monitor='train_loss',
    save_top_k=5,
)

# Single-GPU training; gradients are accumulated over 8 batches per optimiser step.
trainer = Trainer(
    benchmark=True,
    gpus=1,
    accumulate_grad_batches=8,
    logger=logger,
    max_epochs=30,
    callbacks=[checkpoint, lr_logger],
)

  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [None]:
# Start training with the Trainer configured above. `tl` and `vl` are the two
# loaders returned by get_dataloaders, passed positionally after the model
# (train loader first, then validation loader).
trainer.fit(model, tl, vl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type                  | Params
----------------------------------------------------
0 | criterion | CrossEntropyLoss      | 0     
1 | embedding | RemiEmbedding         | 89.9 K
2 | encoder   | TransformerEncoder    | 2.1 M 
3 | decoder   | TransformerMixDecoder | 3.2 M 
4 | heads     | ModuleDict            | 360 K 
----------------------------------------------------
5.7 M     Trainable params
0         Non-trainable params
5.7 M     Total params
22.889    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Global seed set to 42


Training: 0it [00:00, ?it/s]

In [None]:
# Persist the final trainer state next to the top-k checkpoints; the
# generation section loads the model back from this exact path.
last_ckpt_path = f'weights/{name}/last.ckpt'
trainer.save_checkpoint(last_ckpt_path)

## generate

In [None]:
import src.models.mix_enc_dec
reload(src.models.mix_enc_dec)
from src.models.mix_enc_dec import EncoderMixDecoderPerformer

# Restore the trained weights into the same class that was trained above.
# The original cell called `EncoderDecoderPerformer`, a name that is never
# imported or defined in this notebook, so it raised NameError on a fresh kernel.
# NOTE(review): confirm that `generate` switches the model to eval mode;
# if it does not, call `gen_model.eval()` before sampling so dropout is off.
gen_model = EncoderMixDecoderPerformer.load_from_checkpoint(f"weights/{name}/last.ckpt", config=config)

In [None]:
import random

path = data_config['data_dir']
# os.listdir returns entries in arbitrary order; sorting makes the mapping
# from index to file stable between runs.
files = sorted(os.listdir(path))
if not files:
    raise FileNotFoundError(f'no files found in {path}')

# random.randint includes BOTH endpoints, so randint(0, len(files)) could
# return len(files) and make files[idx] raise IndexError. randrange excludes
# the upper bound and always yields a valid index.
idx = random.randrange(len(files))
file = files[idx]
print('idx:', idx, ' file:', file)

# Load the chosen file and keep only the listed instruments.
seq = MusicRepr.from_file(os.path.join(path, file), const=const).keep_instruments(['piano','drums', 'guitar'])
seq.get_instruments()

In [None]:
# Instrument whose track will be generated.
trg_inst = 'drums'

# Prompt = the first 20 entries returned by seq.get_bars(), concatenated,
# with the target instrument removed.
leading_bars = seq.get_bars()[:20]
prompt = MusicRepr.concatenate(leading_bars).remove_instruments([trg_inst])

# Cell result: remaining instruments, prompt length, and REMI token count.
(prompt.get_instruments(), len(prompt), len(prompt.to_remi()))

In [None]:
# Ask the restored model for a `trg_inst` sequence given the prompt, and
# report how many indices came back.
res = gen_model.generate(
    trg_inst,
    seq=prompt,
    window=10,
    top_p=.9,
    t=.8,
)
print(len(res))

# Turn the returned indices back into a MusicRepr using the same constants.
gen_seq = MusicRepr.from_indices(res, const=const)
len(gen_seq)

In [None]:
# Put the generated track alongside the prompt's tracks and merge them.
tracks = prompt.separate_tracks()
tracks[trg_inst] = gen_seq
final_seq = MusicRepr.merge_tracks(tracks)

# Name the output folder after the source file without its extension.
# os.path.splitext handles any extension length ('.mid', '.midi', none),
# whereas the original `file[:-4]` silently assumed exactly four characters.
file_stem = os.path.splitext(file)[0]
save_path = f'assets/EncDec/{file_stem}/'
os.makedirs(save_path, exist_ok=True)

# Write the merged result, the generated track alone, and the prompt alone.
final_seq.to_midi(save_path + f'{trg_inst}_merge.mid')
gen_seq.to_midi(save_path + f'{trg_inst}_gen.mid')
prompt.to_midi(save_path + f'{trg_inst}_prompt.mid')