In [1]:
from importlib import reload
from pathlib import Path

import torch
from deepnote import MusicRepr
from omegaconf import OmegaConf
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

# Single seed for python/numpy/torch RNGs so runs are reproducible.
SEED = 42
seed_everything(SEED)

42

## Config

In [4]:
# Experiment configuration; the 'data' and 'model' sections are consumed below.
CONF_PATH = 'conf.yaml'
conf = OmegaConf.load(CONF_PATH)

## Dataset

In [6]:
from midi_transformer.data import LMDataset, get_dataloaders

# Dataloader settings, kept as named constants so they are easy to find and tweak.
BATCH_SIZE = 2
N_JOBS = 2
VAL_FRAC = 0.1  # fraction of samples held out for the validation loader

dataset = LMDataset(**conf['data'])
train_loader, val_loader = get_dataloaders(
    dataset,
    batch_size=BATCH_SIZE,
    n_jobs=N_JOBS,
    val_frac=VAL_FRAC,
)

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))


train dataset has 47499 samples and val dataset has 5277 samples.


In [4]:
x, y = dataset[0]
# Shapes of the two arrays in the first sample (presumably input and target).
tuple(arr.shape for arr in (x, y))

((512, 8), (512, 8))

In [5]:
# Take one validation batch explicitly. `b` is reused by the model sanity-check cell,
# so fail loudly here if the loader is empty instead of leaving `b` undefined.
try:
    b = next(iter(val_loader))
except StopIteration:
    raise RuntimeError('val_loader yielded no batches') from None

for k in b:
    print(k, b[k].shape)

X torch.Size([2, 512, 8])
X_len torch.Size([2])
labels torch.Size([2, 512, 8])


## Model

In [8]:
from src.model import CPTransformer

# Build the model from the 'model' config section and report its parameter count.
model = CPTransformer(conf['model'])
n_params = model.count_parameters()
print('model has', n_params, 'parameters.')
model

model has 17322172 parameters.


CPTransformer(
  (criterion): CrossEntropyLoss()
  (emb): CPEmbedding(
    (emb_layers): ModuleDict(
      (ttype): Embeddings(
        (lut): Embedding(2, 8)
      )
      (position): Embeddings(
        (lut): Embedding(49, 64)
      )
      (tempo): Embeddings(
        (lut): Embedding(21, 48)
      )
      (chord): Embeddings(
        (lut): Embedding(133, 96)
      )
      (instrument): Embeddings(
        (lut): Embedding(17, 32)
      )
      (pitch): Embeddings(
        (lut): Embedding(128, 96)
      )
      (duration): Embeddings(
        (lut): Embedding(48, 64)
      )
      (velocity): Embeddings(
        (lut): Embedding(30, 64)
      )
    )
    (proj): Linear(in_features=472, out_features=512, bias=True)
    (pos_emb): PositionalEncoding()
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): Encoder(
    (layers): ModuleList(
      (0): EncoderBlock(
        (self_attention): AttentionLayer(
          (inner_attention): LinearAttention(
            (feature_map

In [7]:
# Sanity-check the freshly constructed model on one validation batch.
# model.step returns a scalar tensor (presumably the loss). No backward pass follows,
# so run under no_grad to avoid building and holding an autograd graph.
with torch.no_grad():
    sanity_loss = model.step(b)
sanity_loss

tensor(3.5755, grad_fn=<DivBackward0>)

## Trainer

In [7]:
# Run name: the instrument list from the data config joined with '-'
# (e.g. 'piano-brass-drums'); used for both log and checkpoint directories.
name = '-'.join(conf['data']['instruments'])
print('model name:', name)

logger = TensorBoardLogger(save_dir='logs/', name=name)
lr_logger = LearningRateMonitor(logging_interval='step')

# Keep up to 10 checkpoints ranked by 'val_loss'. The model must log that exact key
# (self.log('val_loss', ...)); otherwise Lightning raises MisconfigurationException
# ("ModelCheckpoint(monitor='val_loss') not found in the returned metrics").
# NOTE(review): `period` is an older Lightning argument (later replaced by
# every_n_epochs); update it if Lightning is upgraded.
checkpoint = ModelCheckpoint(dirpath=f'weights/{name}/',
                             filename='{epoch}-{val_loss:.2f}',
                             monitor='val_loss',
                             save_top_k=10,
                             period=1)

# NOTE(review): `gpus` and `reload_dataloaders_every_epoch` are likewise older-style
# Trainer arguments; confirm against the installed Lightning version.
trainer = Trainer(benchmark=True,
                  gpus=0,
                  reload_dataloaders_every_epoch=True,
                  # Gradient clipping is off; pass gradient_clip_val=0.5 to enable it.
                  accumulate_grad_batches=2,  # one optimizer step per 2 batches
                  logger=logger,
                  max_epochs=conf['model']['max_epochs'],
                  callbacks=[checkpoint, lr_logger])


GPU available: False, used: False
TPU available: False, using: 0 TPU cores


model name: piano-brass-drums


In [8]:
# Train on train_loader and validate on val_loader. The checkpoint callback monitors
# 'val_loss', so the model must log that key; if it does not, Lightning raises
# MisconfigurationException ("not found in the returned metrics").
trainer.fit(model, train_loader, val_loader)


  | Name      | Type             | Params
-----------------------------------------------
0 | criterion | CrossEntropyLoss | 0     
1 | emb       | CPEmbedding      | 276 K 
2 | encoder   | Encoder          | 16.8 M
3 | head      | CPHeadLayer      | 222 K 
I0509 03:10:19.681107 4404151744 lightning.py:1459] 
  | Name      | Type             | Params
-----------------------------------------------
0 | criterion | CrossEntropyLoss | 0     
1 | emb       | CPEmbedding      | 276 K 
2 | encoder   | Encoder          | 16.8 M
3 | head      | CPHeadLayer      | 222 K 


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]



MisconfigurationException: ModelCheckpoint(monitor='val_loss') not found in the returned metrics: ['train_ttype', 'train_position', 'train_tempo', 'train_chord', 'train_instrument', 'train_pitch', 'train_duration', 'train_velocity', 'train_loss']. HINT: Did you call self.log('val_loss', tensor) in the LightningModule?

## Generation

In [9]:
# Sample a short unprompted sequence with per-attribute sampling temperatures,
# convert it to a MusicRepr and write it out as MIDI.
# NOTE(review): if the training cell did not complete, `model` still has its initial weights.
GEN_TEMPERATURES = {
    'ttype' : 1.5,
    'position': 0.9,
    'tempo': 0.5,
    'chord': 0.8,
    'instrument': 1.,
    'pitch': 1.1,
    'duration': 0.8,
    'velocity': 0.8
}
GEN_OUT_PATH = Path('assets/gen.mid')

gen = model.generate(
    prompt=None,
    max_len=10,
    temperatures=GEN_TEMPERATURES
)
seq = MusicRepr.from_cp(gen)
# Create the output directory so the write does not fail on a fresh checkout.
GEN_OUT_PATH.parent.mkdir(parents=True, exist_ok=True)
seq.to_midi(str(GEN_OUT_PATH))

  0%|          | 0/10 [00:00<?, ?it/s]

  return torch.multinomial(F.softmax(x.cpu().detach()), num_samples=1, replacement=True)


ticks per beat: 120
max tick: 510
tempo changes: 0
time sig: 1
key sig: 0
markers: 3
lyrics: False
instruments: 4