In [1]:
import numpy as np
import torch
import os
from tqdm.notebook import tqdm

from deepnote import MusicRepr, Constants
from importlib import reload

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

seed_everything(42)

  rank_zero_deprecation(
Global seed set to 42


42

## data

In [2]:
const = Constants(unit=4, num_tempo_bins=20, num_velocity_bins=20)

data_config = {
#     'data_dir' : '/home/soroosh/data/MIDI/pop909/train/',
#     'data_dir' : '/home/soroosh/data/MIDI/e-gmd-v1.0.0/midis_processed/',
    'data_dir' : '/home/soroosh/data/MIDI/lmd_processed/',
    'const' : const,
    'instruments' : ['piano', 'drums', 'guitar'],
    'max_files' : 10,
    'window_len' : 5,
    'pad_value' : 0,
    'n_jobs' : 20
}

name = 'small-lmd-win10'
print('model name:',name)

model name: small-lmd-win10


In [3]:
import src.data
reload(src.data)
from src.data import MidiDataset, get_dataloaders

dataset = MidiDataset(**data_config)
n = len(dataset)
n

  0%|          | 0/10 [00:00<?, ?it/s]

914

In [4]:
sample = dataset[100]
# for k in sample:
#     print(k, len(sample[k]))
for inst in sample:
    print(inst)
    for k in sample[inst]:
        print('  ',k, len(sample[inst][k]))

piano
   src 684
   trg 146
drums
   src 421
   trg 441
guitar
   src 653
   trg 200


In [5]:
tl, vl = get_dataloaders(dataset, batch_size=4, n_jobs=2)

In [6]:
b = next(iter(vl))
for inst in b:
    print(inst)
    for k in b[inst]:
        print('   ', k, b[inst][k].shape)

piano
    src torch.Size([2, 1462])
    trg torch.Size([2, 225])
    src_len torch.Size([2])
    trg_len torch.Size([2])
    labels torch.Size([2, 225])
drums
    src torch.Size([4, 1695])
    trg torch.Size([4, 643])
    src_len torch.Size([4])
    trg_len torch.Size([4])
    labels torch.Size([4, 643])
guitar
    src torch.Size([4, 1551])
    trg torch.Size([4, 780])
    src_len torch.Size([4])
    trg_len torch.Size([4])
    labels torch.Size([4, 780])


## model

In [7]:
import src.modules.decoder
reload(src.modules.decoder)
import src.modules
reload(src.modules)

import src.models.baseline
reload(src.models.baseline)
from src.models.baseline import BasePerformer

In [8]:
d_model = 256
n_vocab = len(const.all_tokens)
dropout = 0.1
config = {
    'lr' : 1e-4,
    'instruments' : ['piano', 'drums', 'guitar'],
    'embedding': {
        'd_model' : d_model,
        'positional_embedding' : 'relative',
        'n_vocab' : n_vocab,
        'dropout' : dropout,
        'max_len' : 10000
    },
    'decoder' : {
        'd_model' : d_model,
        'n_head' : 8,
        'd_inner' : 512,
        'dropout' : dropout,
        'n_layer' : 1,
        'share_weights' : False
    },
    'head' : {
        'd_model' : d_model,
        'n_vocab' : n_vocab
    }
}

model = BasePerformer(config)
model.count_parameters()

1151261

In [9]:
logits, loss = model.forward(
    inst='piano', 
    src=b['piano']['src'], 
    src_len=b['piano']['src_len'], 
    trg=b['piano']['trg'], 
    trg_len=b['piano']['trg_len'], 
    labels=b['piano']['labels']
)
loss

tensor(6.1576, grad_fn=<DivBackward0>)

## train

In [10]:
logger = TensorBoardLogger(save_dir='logs/', name=name)
lr_logger = LearningRateMonitor(logging_interval='step')
checkpoint = ModelCheckpoint(
    dirpath=f'weights/{name}/', 
    filename='{epoch}-{val_loss:.2f}', 
    monitor='train_loss',
    save_top_k=5, 
    period=1
)

trainer = Trainer(
    benchmark=True, 
    gpus=1, 
    accumulate_grad_batches=1,
    logger=logger, 
    max_epochs=20,
    callbacks=[checkpoint, lr_logger]
)

  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [11]:
trainer.fit(model, tl, vl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params
-------------------------------------------------
0 | criterion | CrossEntropyLoss   | 0     
1 | embedding | RemiEmbedding      | 89.9 K
2 | decoder   | TransformerDecoder | 790 K 
3 | heads     | ModuleDict         | 270 K 
-------------------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.605     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Global seed set to 42


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [12]:
trainer.save_checkpoint(f'weights/{name}/last.ckpt')

## generate