In [1]:
import numpy as np
import torch
from torch import nn
import os
from tqdm.notebook import tqdm

from deepnote import MusicRepr, Constants
from importlib import reload

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

# Fix the global random seed through pytorch_lightning so repeated runs start from the same RNG state.
seed_everything(42)

  rank_zero_deprecation(
Global seed set to 42


42

## data

In [2]:
# Constants object handed to the dataset below via data_config['const'].
const = Constants(unit=4, num_tempo_bins=20, num_velocity_bins=20)

# Keyword arguments for the dataset constructor.
# Other corpora that were previously used as 'data_dir':
#     '/home/soroosh/data/MIDI/pop909/train/'
#     '/home/soroosh/data/MIDI/e-gmd-v1.0.0/midis_processed/'
data_config = dict(
    data_dir='/home/soroosh/data/MIDI/lmd_processed/',
    const=const,
    src_instruments=['piano', 'drums', 'guitar'],
    trg_instruments=['piano', 'drums'],
    max_files=100,
    window_len=3,
    pad_value=0,
    n_jobs=20,
)

# Identifier for this run.
name = 'small-lmd-win3'
print('model name:', name)

model name: small-lmd-win3


In [3]:
# Import then reload src.data so that changes to the module on disk are picked up
# without restarting the kernel; the names are re-imported after the reload.
import src.data
reload(src.data)
from src.data import MidiDataset, get_dataloaders

# Build the dataset from the settings collected in data_config.
dataset = MidiDataset(**data_config)
n = len(dataset)
# Cell output: number of dataset items and the length of dataset.lens.
n, len(dataset.lens)

  0%|          | 0/100 [00:00<?, ?it/s]

(7771, 83)

In [18]:
# Inspect one dataset item: for every instrument, list each field and its length.
sample = dataset[100]
for inst in sample:
    print(inst)
    entry = sample[inst]
    for key in entry:
        print(f'   {key} {len(entry[key])}')

piano
   src 511
   trg 315
drums
   src 631
   trg 195


In [19]:
# Create the two dataloaders that are later passed to trainer.fit as the train (tl)
# and validation (vl) loaders. n_jobs is presumably the number of loader workers - TODO confirm.
tl, vl = get_dataloaders(dataset, batch_size=2, n_jobs=2)

In [20]:
# Pull the first batch from vl and print the tensor shape of every field, grouped by instrument.
b = next(iter(vl))
for inst in b:
    print(inst)
    fields = b[inst]
    for key in fields:
        print(f'    {key} {fields[key].shape}')

drums
    src torch.Size([2, 313])
    trg torch.Size([2, 217])
    src_len torch.Size([2])
    trg_len torch.Size([2])
    labels torch.Size([2, 217])
piano
    src torch.Size([1, 306])
    trg torch.Size([1, 14])
    src_len torch.Size([1])
    trg_len torch.Size([1])
    labels torch.Size([1, 14])


## model

In [7]:
# Reload the model source modules so BasePerformer reflects the current files on disk.
# NOTE(review): the order (att, decoder, package, baseline) presumably reloads
# dependencies before the modules that import them - keep the order unchanged.
import src.modules.att
reload(src.modules.att)

import src.modules.decoder
reload(src.modules.decoder)

import src.modules
reload(src.modules)

import src.models.baseline
reload(src.models.baseline)
from src.models.baseline import BasePerformer

In [8]:
# Values reused by several sections of the model configuration.
d_model = 256
n_vocab = len(const.all_tokens)  # vocabulary size = number of tokens defined by const
dropout = 0.1

# Configuration passed to BasePerformer: top-level 'lr' and 'instruments',
# plus one nested section each for 'embedding', 'decoder' and 'head'.
config = dict(
    lr=1e-4,
    instruments=['piano', 'drums', 'guitar'],
    embedding=dict(
        d_model=d_model,
        positional_embedding='relative',
        n_vocab=n_vocab,
        dropout=dropout,
        max_len=10000,
    ),
    decoder=dict(
        d_model=d_model,
        n_head=8,
        d_inner=512,
        dropout=dropout,
        n_layer=4,
    ),
    head=dict(
        d_model=d_model,
        n_vocab=n_vocab,
    ),
)

model = BasePerformer(config)
# Cell output: the value returned by count_parameters().
model.count_parameters()

3523613

In [9]:
# logits, loss = model('piano', **b['piano'])
# loss

## train

In [10]:
# TensorBoard logs go to logs/<name>/; the learning rate is recorded at every step.
logger = TensorBoardLogger(save_dir='logs/', name=name)
lr_logger = LearningRateMonitor(logging_interval='step')

# Keep the five best checkpoints under weights/<name>/.
# The deprecated `period` argument is omitted: its value (1) equals the default
# checkpointing frequency, so behaviour is unchanged and the deprecation warning goes away.
# NOTE(review): checkpoints are ranked by 'train_loss' while the filename reports
# 'val_loss' - confirm which metric is meant to select the best checkpoints.
checkpoint = ModelCheckpoint(
    dirpath=f'weights/{name}/',
    filename='{epoch}-{val_loss:.2f}',
    monitor='train_loss',
    save_top_k=5,
)

# Single-GPU trainer with no gradient accumulation, running for 10 epochs.
trainer = Trainer(
    benchmark=True,
    gpus=1,
    accumulate_grad_batches=1,
    logger=logger,
    max_epochs=10,
    callbacks=[checkpoint, lr_logger]
)

  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [22]:
# Train the model, using tl as the training loader and vl as the validation loader.
trainer.fit(model, tl, vl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params
-------------------------------------------------
0 | criterion | CrossEntropyLoss   | 0     
1 | embedding | RemiEmbedding      | 89.9 K
2 | decoder   | TransformerDecoder | 3.2 M 
3 | heads     | ModuleDict         | 270 K 
-------------------------------------------------
3.5 M     Trainable params
0         Non-trainable params
3.5 M     Total params
14.094    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Global seed set to 42


Training: 3496it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceed

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [21]:
# Raise the epoch limit on the existing trainer.
# NOTE(review): this cell sits after the trainer.fit cell but carries a lower execution
# count, i.e. it was run before fit. On a top-to-bottom run the preceding fit call only
# sees the max_epochs given to Trainer(...) - confirm the intended order.
trainer.max_epochs = 20

In [23]:
# Save the trainer's current state to weights/<name>/last.ckpt.
trainer.save_checkpoint(f'weights/{name}/last.ckpt')

## generate

In [24]:
# Reload the model source modules so that BasePerformer is taken from the
# current files on disk before the checkpoint is loaded.
import src.modules.utils
reload(src.modules.utils)

import src.modules
reload(src.modules)

import src.models.baseline
reload(src.models.baseline)
from src.models.baseline import BasePerformer

# Restore a model from weights/<name>/last.ckpt, passing the same config dict used for training.
gen_model = BasePerformer.load_from_checkpoint(f"weights/{name}/last.ckpt", config=config)

In [44]:
import random

# Pick one file at random from the dataset directory and load it.
path = data_config['data_dir']
# os.listdir returns entries in arbitrary order; sorting makes a printed idx
# refer to the same file on every run and every machine.
files = sorted(os.listdir(path))
# Draw the index from the real number of files: a fixed upper bound raises
# IndexError on smaller directories and never reaches files beyond it.
idx = random.randrange(len(files))
file = files[idx]
print('idx:', idx, ' file:', file)
# Keep only the three instruments listed here.
seq = MusicRepr.from_file(os.path.join(path, file), const=const).keep_instruments(['piano','drums', 'guitar'])
# Cell output: the instruments present in the loaded sequence.
seq.get_instruments()

idx: 142  file: a4a962b91e744c86baf7c74d616a3ed7.mid


['piano', 'drums', 'guitar']

In [45]:
# Build the prompt from the first 10 bars of the loaded piece, with drums removed.
leading_bars = seq.get_bars()[:10]
prompt = MusicRepr.concatenate(leading_bars).remove_instruments(['drums'])
# Cell output: remaining instruments, len(prompt) and len(prompt.to_remi()).
prompt.get_instruments(), len(prompt), len(prompt.to_remi())

(['piano', 'guitar'], 256, 826)

In [48]:
# Generate a 'drums' part with the prompt as the input sequence.
res = gen_model.generate(
    'drums',
    seq=prompt,
    window=10,
    top_p=0.9,
    t=0.9,
)
print(len(res))

# Convert the generated indices back into a MusicRepr.
gen_seq = MusicRepr.from_indices(res, const=const)
# Cell output: length of the generated sequence and its bar count.
len(gen_seq), gen_seq.get_bar_count()

  0%|          | 0/10 [00:00<?, ?it/s]

529


(182, 10)

In [49]:
# Display every event of the generated sequence (full slice).
gen_seq[:]

[Bar(position=0, tempo=115),
 Bar(position=0, tempo=115),
 Beat(position=2),
 Note(inst_family=drums, pitch=42, duration=1, velocity=60),
 Beat(position=4),
 Note(inst_family=drums, pitch=42, duration=1, velocity=54),
 Beat(position=6),
 Note(inst_family=drums, pitch=42, duration=1, velocity=54),
 Beat(position=8),
 Note(inst_family=drums, pitch=42, duration=1, velocity=54),
 Beat(position=10),
 Note(inst_family=drums, pitch=42, duration=1, velocity=34),
 Beat(position=12),
 Note(inst_family=drums, pitch=42, duration=1, velocity=60),
 Note(inst_family=drums, pitch=75, duration=1, velocity=67),
 Beat(position=14),
 Note(inst_family=drums, pitch=42, duration=1, velocity=47),
 Bar(position=0, chord=F#_M7),
 Note(inst_family=drums, pitch=42, duration=1, velocity=100),
 Beat(position=2),
 Note(inst_family=drums, pitch=36, duration=1, velocity=100),
 Note(inst_family=drums, pitch=42, duration=1, velocity=100),
 Beat(position=4),
 Note(inst_family=drums, pitch=38, duration=1, velocity=93),
 B

In [50]:
# Combine the prompt with the generated drums and export the result.
# separate_tracks() presumably returns a mapping keyed by instrument name - TODO confirm.
tracks = prompt.separate_tracks()
tracks['drums'] = gen_seq  # add the generated sequence under the 'drums' key
final_seq = MusicRepr.merge_tracks(tracks)
final_seq.to_midi('test.mid')  # written relative to the current working directory

ticks per beat: 120
max tick: 4830
tempo changes: 1
time sig: 1
key sig: 0
markers: 8
lyrics: False
instruments: 3

In [51]:
# Export the generated sequence on its own to gen.mid.
gen_seq.to_midi('gen.mid')

ticks per beat: 120
max tick: 4770
tempo changes: 2
time sig: 1
key sig: 0
markers: 18
lyrics: False
instruments: 1