In [1]:
import os
import torch
import random

from miditok import REMI
from miditok.utils import get_midi_programs
from miditoolkit import MidiFile
from pathlib import Path

os.environ['WANDB_DISABLED'] = 'true'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['TORCH_USE_CUDA_DSA'] = '1'

seed = random.randint(1000, 10000)
tokenizer = REMI()

In [2]:
from torchtoolkit.data import create_subsets
from utils.midi_dataset import MIDIDataset

#tokens = tokenizer.load_tokens(Path('/home/nico/data/ai/models/midi/tokens.json'))
params = tokenizer.load_params(
    Path('/home/nico/data/ai/models/midi/token_params.json'))

midi_dataset = MIDIDataset(
    files_paths=list(
        Path('/home/nico/data/ai/models/midi/').glob('ozzy*.json')),
    min_seq_len=24,
    max_seq_len=128
)
subset_train, subset_valid = create_subsets(midi_dataset, [0.3])

len(subset_valid)

Loading data: /home/nico/data/ai/models/midi: 100%|██████████| 9/9 [00:00<00:00, 287.69it/s]


62

In [3]:
len(tokenizer)

221

In [4]:
from transformers import GPT2LMHeadModel, GPT2Config, Trainer, TrainingArguments, GenerationConfig

# Creates model
config = GPT2Config(
    vocab_size=len(tokenizer),
    n_positions=2048,
    n_embd=512,
    n_layer=8,
    n_head=8,
    n_inner=2048,
    resid_pdrop=.1,
    embd_pdrop=.1,
    attn_pdrop=.1,
    padding_token_id=tokenizer['PAD_None'],
    bos_token_id=tokenizer['BOS_None'],
    eos_token_id=tokenizer['EOS_None']
)

model = GPT2LMHeadModel(config)
model.eval()

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(221, 512)
    (wpe): Embedding(2048, 512)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-7): 8 x GPT2Block(
        (ln_1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=512, out_features=221, bias=False)
)

In [5]:
from utils.midi_dataset import DataCollatorGen
from evaluate import load as load_metric
from torch import Tensor, argmax

metrics = {metric: load_metric(metric) for metric in ["accuracy"]}


def compute_metrics(eval_pred):
    """Computes metrics for pretraining.
    Must use proprocess_logits function that converts logits to predictions (argmax or sampling).

    :param eval_pred: EvalPrediction containing predictions and labels
    :return: metrics
    """
    predictions, labels = eval_pred
    not_pad_mask = labels != -100
    labels, predictions = labels[not_pad_mask], predictions[not_pad_mask]
    computed = metrics["accuracy"].compute(
        predictions=predictions.flatten(), references=labels.flatten())

    return computed


def preprocess_logits(logits: Tensor, _: Tensor) -> Tensor:
    """Preprocesses the logits before accumulating them during evaluation.
    This allows to significantly reduce the memory usage and make the training tractable.
    """
    pred_ids = argmax(logits, dim=-1)  # long dtype
    return pred_ids


training_config = TrainingArguments(
    "runs", False, True, True, False, "steps",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=2,
    eval_accumulation_steps=2,
    eval_steps=100,
    learning_rate=1e-4,
    weight_decay=0.01,
    max_grad_norm=3.0,
    max_steps=1000,
    lr_scheduler_type="cosine_with_restarts",
    warmup_ratio=0.3,
    log_level="debug",
    logging_strategy="steps",
    logging_steps=20,
    save_strategy="steps",
    metric_for_best_model='accuracy',
    greater_is_better=True,
    save_steps=100,
    save_total_limit=5,
    no_cuda=False,
    seed=seed,
    fp16=True,
    load_best_model_at_end=True,
    label_smoothing_factor=0.,
    optim="adamw_hf",
    gradient_checkpointing=True,
    dataloader_num_workers=2,
    auto_find_batch_size=True
)

trainer = Trainer(
    model=model,
    args=training_config,
    data_collator=DataCollatorGen(tokenizer["PAD_None"]),
    train_dataset=subset_train,
    eval_dataset=subset_valid,
    compute_metrics=compute_metrics,
    callbacks=None,
    preprocess_logits_for_metrics=preprocess_logits,
)

# Training
train_result = trainer.train()
trainer.save_model()  # Saves the tokenizer too
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
max_steps is given, it will override any value given in num_train_epochs
Using cuda_amp half precision backend
***** Running training *****
  Num examples = 145
  Num Epochs = 28
  Instantaneous batch size per device = 2
  Total train batch size (w. parallel, distributed & accumulation) = 4
  Gradient Accumulation steps = 2
  Total optimization steps = 1,000
  Number of trainable parameters = 26,381,824


  0%|          | 0/1000 [00:00<?, ?it/s]

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


{'loss': 5.2089, 'learning_rate': 6.666666666666667e-06, 'epoch': 0.55}
{'loss': 3.9942, 'learning_rate': 1.3333333333333333e-05, 'epoch': 1.1}
{'loss': 3.2065, 'learning_rate': 2e-05, 'epoch': 1.64}
{'loss': 2.7077, 'learning_rate': 2.6666666666666667e-05, 'epoch': 2.19}


***** Running Evaluation *****
  Num examples = 62
  Batch size = 2


{'loss': 2.2786, 'learning_rate': 3.3333333333333335e-05, 'epoch': 2.74}


  0%|          | 0/31 [00:00<?, ?it/s]

Saving model checkpoint to runs/checkpoint-100
Configuration saved in runs/checkpoint-100/config.json
Configuration saved in runs/checkpoint-100/generation_config.json


{'eval_loss': 2.044809579849243, 'eval_accuracy': 0.0068319838056680165, 'eval_runtime': 1.489, 'eval_samples_per_second': 41.639, 'eval_steps_per_second': 20.819, 'epoch': 2.74}


Model weights saved in runs/checkpoint-100/pytorch_model.bin


{'loss': 2.0788, 'learning_rate': 4e-05, 'epoch': 3.29}
{'loss': 1.9711, 'learning_rate': 4.666666666666667e-05, 'epoch': 3.84}
{'loss': 1.8444, 'learning_rate': 5.333333333333333e-05, 'epoch': 4.38}
{'loss': 1.7867, 'learning_rate': 6e-05, 'epoch': 4.93}


***** Running Evaluation *****
  Num examples = 62
  Batch size = 2


{'loss': 1.6621, 'learning_rate': 6.666666666666667e-05, 'epoch': 5.48}


  0%|          | 0/31 [00:00<?, ?it/s]

Saving model checkpoint to runs/checkpoint-200
Configuration saved in runs/checkpoint-200/config.json
Configuration saved in runs/checkpoint-200/generation_config.json


{'eval_loss': 1.6472399234771729, 'eval_accuracy': 0.001771255060728745, 'eval_runtime': 1.4467, 'eval_samples_per_second': 42.856, 'eval_steps_per_second': 21.428, 'epoch': 5.48}


Model weights saved in runs/checkpoint-200/pytorch_model.bin


{'loss': 1.57, 'learning_rate': 7.333333333333333e-05, 'epoch': 6.03}
{'loss': 1.5773, 'learning_rate': 8e-05, 'epoch': 6.58}
{'loss': 1.5249, 'learning_rate': 8.666666666666667e-05, 'epoch': 7.12}
{'loss': 1.3723, 'learning_rate': 9.333333333333334e-05, 'epoch': 7.67}


***** Running Evaluation *****
  Num examples = 62
  Batch size = 2


{'loss': 1.4194, 'learning_rate': 0.0001, 'epoch': 8.22}


  0%|          | 0/31 [00:00<?, ?it/s]

Saving model checkpoint to runs/checkpoint-300
Configuration saved in runs/checkpoint-300/config.json
Configuration saved in runs/checkpoint-300/generation_config.json


{'eval_loss': 1.4831056594848633, 'eval_accuracy': 0.0036690283400809716, 'eval_runtime': 1.4, 'eval_samples_per_second': 44.285, 'eval_steps_per_second': 22.142, 'epoch': 8.22}


Model weights saved in runs/checkpoint-300/pytorch_model.bin


{'loss': 1.3373, 'learning_rate': 9.979871469976196e-05, 'epoch': 8.77}
{'loss': 1.3088, 'learning_rate': 9.919647942993148e-05, 'epoch': 9.32}
{'loss': 1.1787, 'learning_rate': 9.819814303479267e-05, 'epoch': 9.86}
{'loss': 1.0688, 'learning_rate': 9.681174353198687e-05, 'epoch': 10.41}


***** Running Evaluation *****
  Num examples = 62
  Batch size = 2


{'loss': 1.1103, 'learning_rate': 9.504844339512095e-05, 'epoch': 10.96}


  0%|          | 0/31 [00:00<?, ?it/s]

Saving model checkpoint to runs/checkpoint-400
Configuration saved in runs/checkpoint-400/config.json
Configuration saved in runs/checkpoint-400/generation_config.json


{'eval_loss': 1.3541311025619507, 'eval_accuracy': 0.005060728744939271, 'eval_runtime': 1.4476, 'eval_samples_per_second': 42.831, 'eval_steps_per_second': 21.415, 'epoch': 10.96}


Model weights saved in runs/checkpoint-400/pytorch_model.bin


{'loss': 0.9992, 'learning_rate': 9.292243968009331e-05, 'epoch': 11.51}
{'loss': 1.0008, 'learning_rate': 9.045084971874738e-05, 'epoch': 12.05}
{'loss': 0.8568, 'learning_rate': 8.765357330018056e-05, 'epoch': 12.6}
{'loss': 0.8321, 'learning_rate': 8.455313244934324e-05, 'epoch': 13.15}


***** Running Evaluation *****
  Num examples = 62
  Batch size = 2


{'loss': 0.8216, 'learning_rate': 8.117449009293668e-05, 'epoch': 13.7}


  0%|          | 0/31 [00:00<?, ?it/s]

Saving model checkpoint to runs/checkpoint-500
Configuration saved in runs/checkpoint-500/config.json
Configuration saved in runs/checkpoint-500/generation_config.json


{'eval_loss': 1.2271618843078613, 'eval_accuracy': 0.008097165991902834, 'eval_runtime': 1.4831, 'eval_samples_per_second': 41.803, 'eval_steps_per_second': 20.902, 'epoch': 13.7}


Model weights saved in runs/checkpoint-500/pytorch_model.bin


{'loss': 0.7279, 'learning_rate': 7.754484907260513e-05, 'epoch': 14.25}
{'loss': 0.6987, 'learning_rate': 7.369343312364993e-05, 'epoch': 14.79}
{'loss': 0.6533, 'learning_rate': 6.965125158269619e-05, 'epoch': 15.34}
{'loss': 0.6383, 'learning_rate': 6.545084971874738e-05, 'epoch': 15.89}


***** Running Evaluation *****
  Num examples = 62
  Batch size = 2


{'loss': 0.5569, 'learning_rate': 6.112604669781572e-05, 'epoch': 16.44}


  0%|          | 0/31 [00:00<?, ?it/s]

Saving model checkpoint to runs/checkpoint-600
Configuration saved in runs/checkpoint-600/config.json
Configuration saved in runs/checkpoint-600/generation_config.json


{'eval_loss': 1.1665607690811157, 'eval_accuracy': 0.005819838056680162, 'eval_runtime': 1.3023, 'eval_samples_per_second': 47.607, 'eval_steps_per_second': 23.803, 'epoch': 16.44}


Model weights saved in runs/checkpoint-600/pytorch_model.bin
Deleting older checkpoint [runs/checkpoint-100] due to args.save_total_limit


{'loss': 0.5219, 'learning_rate': 5.6711663290882776e-05, 'epoch': 16.99}
{'loss': 0.4473, 'learning_rate': 5.2243241517525754e-05, 'epoch': 17.53}
{'loss': 0.4815, 'learning_rate': 4.775675848247427e-05, 'epoch': 18.08}
{'loss': 0.4056, 'learning_rate': 4.328833670911724e-05, 'epoch': 18.63}


***** Running Evaluation *****
  Num examples = 62
  Batch size = 2


{'loss': 0.4161, 'learning_rate': 3.887395330218429e-05, 'epoch': 19.18}


  0%|          | 0/31 [00:00<?, ?it/s]

Saving model checkpoint to runs/checkpoint-700
Configuration saved in runs/checkpoint-700/config.json
Configuration saved in runs/checkpoint-700/generation_config.json


{'eval_loss': 1.1532626152038574, 'eval_accuracy': 0.010880566801619434, 'eval_runtime': 1.4545, 'eval_samples_per_second': 42.627, 'eval_steps_per_second': 21.313, 'epoch': 19.18}


Model weights saved in runs/checkpoint-700/pytorch_model.bin
Deleting older checkpoint [runs/checkpoint-200] due to args.save_total_limit


{'loss': 0.3645, 'learning_rate': 3.4549150281252636e-05, 'epoch': 19.73}
{'loss': 0.3144, 'learning_rate': 3.0348748417303823e-05, 'epoch': 20.27}
{'loss': 0.3242, 'learning_rate': 2.630656687635007e-05, 'epoch': 20.82}
{'loss': 0.2809, 'learning_rate': 2.245515092739488e-05, 'epoch': 21.37}


***** Running Evaluation *****
  Num examples = 62
  Batch size = 2


{'loss': 0.2769, 'learning_rate': 1.8825509907063327e-05, 'epoch': 21.92}


  0%|          | 0/31 [00:00<?, ?it/s]

Saving model checkpoint to runs/checkpoint-800
Configuration saved in runs/checkpoint-800/config.json
Configuration saved in runs/checkpoint-800/generation_config.json


{'eval_loss': 1.1487267017364502, 'eval_accuracy': 0.008223684210526315, 'eval_runtime': 1.5001, 'eval_samples_per_second': 41.332, 'eval_steps_per_second': 20.666, 'epoch': 21.92}


Model weights saved in runs/checkpoint-800/pytorch_model.bin
Deleting older checkpoint [runs/checkpoint-300] due to args.save_total_limit


{'loss': 0.2538, 'learning_rate': 1.544686755065677e-05, 'epoch': 22.47}
{'loss': 0.2584, 'learning_rate': 1.2346426699819458e-05, 'epoch': 23.01}
{'loss': 0.2249, 'learning_rate': 9.549150281252633e-06, 'epoch': 23.56}
{'loss': 0.2179, 'learning_rate': 7.077560319906695e-06, 'epoch': 24.11}


***** Running Evaluation *****
  Num examples = 62
  Batch size = 2


{'loss': 0.2109, 'learning_rate': 4.951556604879048e-06, 'epoch': 24.66}


  0%|          | 0/31 [00:00<?, ?it/s]

Saving model checkpoint to runs/checkpoint-900
Configuration saved in runs/checkpoint-900/config.json
Configuration saved in runs/checkpoint-900/generation_config.json


{'eval_loss': 1.1691257953643799, 'eval_accuracy': 0.008350202429149798, 'eval_runtime': 1.4726, 'eval_samples_per_second': 42.104, 'eval_steps_per_second': 21.052, 'epoch': 24.66}


Model weights saved in runs/checkpoint-900/pytorch_model.bin
Deleting older checkpoint [runs/checkpoint-400] due to args.save_total_limit


{'loss': 0.221, 'learning_rate': 3.18825646801314e-06, 'epoch': 25.21}
{'loss': 0.2168, 'learning_rate': 1.8018569652073381e-06, 'epoch': 25.75}
{'loss': 0.2076, 'learning_rate': 8.035205700685167e-07, 'epoch': 26.3}
{'loss': 0.184, 'learning_rate': 2.012853002380466e-07, 'epoch': 26.85}


***** Running Evaluation *****
  Num examples = 62
  Batch size = 2


{'loss': 0.2121, 'learning_rate': 0.0, 'epoch': 27.4}


  0%|          | 0/31 [00:00<?, ?it/s]

Saving model checkpoint to runs/checkpoint-1000
Configuration saved in runs/checkpoint-1000/config.json
Configuration saved in runs/checkpoint-1000/generation_config.json


{'eval_loss': 1.173716425895691, 'eval_accuracy': 0.00847672064777328, 'eval_runtime': 1.4361, 'eval_samples_per_second': 43.174, 'eval_steps_per_second': 21.587, 'epoch': 27.4}


Model weights saved in runs/checkpoint-1000/pytorch_model.bin
Deleting older checkpoint [runs/checkpoint-500] due to args.save_total_limit


Training completed. Do not forget to share your model on huggingface.co/models =)


Loading best model from runs/checkpoint-700 (score: 0.010880566801619434).
Saving model checkpoint to runs
Configuration saved in runs/config.json
Configuration saved in runs/generation_config.json


{'train_runtime': 304.041, 'train_samples_per_second': 13.156, 'train_steps_per_second': 3.289, 'train_loss': 1.0806627798080444, 'epoch': 27.4}


Model weights saved in runs/pytorch_model.bin


***** train metrics *****
  epoch                    =       27.4
  train_loss               =     1.0807
  train_runtime            = 0:05:04.04
  train_samples_per_second =     13.156
  train_steps_per_second   =      3.289


In [9]:
import json

from typing import Any, Dict, List
from torch import LongTensor, flip, cat, full
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
from copy import deepcopy
from transformers import GPT2LMHeadModel, GPT2Config, Trainer, TrainingArguments, GenerationConfig

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = GPT2LMHeadModel.from_pretrained('./runs/')
model = model.to(device)

def collate_gen_left(batch: List[Dict[str, LongTensor]]) -> LongTensor:
    # Here the sequences are padded to the left, so that the last token along the time dimension
    # is always the last token of each seq, allowing to efficiently generate by batch
    bos_shape = (1,)
    batch = [flip(cat([full(bos_shape, tokenizer["BOS_None"]),
                  seq["input_ids"]], dim=0), dims=(0,)) for seq in batch]
    batch = pad_sequence(batch, batch_first=True,
                         padding_value=tokenizer["PAD_None"])  # (N,T) or (N,T,Z)
    batch = flip(batch, dims=(1,)).long()
    return batch  # (N,T)


generation_config = GenerationConfig(
    min_new_tokens=512,
    max_new_tokens=1024,
    repetition_penalty=1.5,
    num_beams=24,        # no beam search
    early_stopping=True,
    no_repeat_ngram_size=5,
    num_return_sequences=16,
    length_penalty=0.5,
    num_beam_groups=4,
    diversity_penalty=2.0,
    do_sample=False,     # but sample instead
    temperature=0.75,
    top_k=35,
    top_p=0.35,
    epsilon_cutoff=3e-4,
    eta_cutoff=1e-3,
    pad_token_id=config.padding_token_id,
)

(gen_results_path := Path('gen_res')).mkdir(parents=True, exist_ok=True)

def rec_gen(tokens):
    global model
    global generation_config

    res = model.generate(torch.LongTensor([tokens]).to(model.device),
                         generation_config=generation_config)
    
    out = res[0].cpu().numpy().tolist()
    new_tokens = out[len(tokens)-1:]

    print(f'Generated {len(new_tokens)} new tokens.')

    return new_tokens


max_iter = 2
iter_count = 0
init_size = 256

with open('/home/nico/data/ai/models/midi/ozzy_osbourne-facing_hell.json') as tokens_file:
    ids = json.load(tokens_file)['ids']
    tokens = ids[0][:init_size] # 1 channel only

    while iter_count < max_iter:
        block_size = init_size if iter_count == 0 else int(init_size / 128)
        tokens += rec_gen(tokens[-block_size:])

        iter_count += 1

loading configuration file ./runs/config.json
Model config GPT2Config {
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 1,
  "embd_pdrop": 0.1,
  "eos_token_id": 2,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_embd": 512,
  "n_head": 8,
  "n_inner": 2048,
  "n_layer": 8,
  "n_positions": 2048,
  "padding_token_id": 0,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "torch_dtype": "float32",
  "transformers_version": "4.29.0.dev0",
  "use_cache": true,
  "vocab_size": 221
}

loading weights file ./runs/pytorch_model.bin
Generate config GenerationConfig {
  "_from_model_config": true,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "tran

In [7]:

print('Generating the midi...')

midi = tokenizer.tokens_to_midi(torch.LongTensor([tokens]).cpu(), time_division=384)
# midi.instruments[0].name = f'Continuation of original sample ({len(generated)} tokens)'
# midi.instruments[1].name = f'Original sample ({len(prompt)} tokens)'
# midi.instruments[2].name = f'Original sample and continuation'
midi.dump(gen_results_path / 'full.mid')
# tokenizer.save_tokens(tokens, gen_results_path / f'{count}.json')

type(midi)

Generating the midi...


miditoolkit.midi.parser.MidiFile