# Piano Playalong Generation from MIDI

## Description

Add Description

## Tokenization

We use the MIDITok Tokenizer to create Tokens from our MIDI files:

https://miditok.readthedocs.io/




### Imports

In [2]:
import numpy as np
import pandas as pd
import os
import json
from pathlib import Path
from tqdm import tqdm
import random
import torch
from torch import nn, optim
from torch.utils.data import Dataset, ConcatDataset, DataLoader, random_split
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import lightning as L
# necessary cuda version on your system
print(torch.version.cuda)
# seed random module
random.seed(42)
#! pip install partitura miditok transformers[torch]

import partitura as pt
from miditok import Structured, TokenizerConfig
from miditok.pytorch_data import DatasetMIDI, DataCollator, split_midis_for_training


None


### Create Tokenizer
Using the "Structured"-Tokenizer from MidiTok

In [3]:
# parameters
TOKENIZER_PARAMS = {
    "pitch_range": (21, 109),
    "beat_res": {(0, 4): 8, (4, 12): 4},
    "num_velocities": 32,
    "special_tokens": ["PAD", "BOS", "EOS", "MASK"],
    "use_chords": False,
    "use_rests": False,
    "use_tempos": False,
    "use_time_signatures": False,
    "use_programs": False,
    "num_tempos": 32,  # number of tempo bins
    "tempo_range": (40, 200),  # (min, max)
}
# Set to True if you want to use BPE
USE_BPE = False

config = TokenizerConfig(**TOKENIZER_PARAMS)

# Creates the tokenizer
tokenizer = Structured(config)


### Load data and tokenize
Here you can filter out Midi-files you want to exclude (e.g. Tuning Tracks - Use lookup.json file to inspect data)

In [4]:
midi_paths = list(Path("data").glob("**/*.mid"))
lookup_path = Path("data", "lookup.json")

# files to exclude:
idx_del = []
lookup = json.loads(lookup_path.read_text())
midi_paths_cleaned = [] 
# files containing this string will be excluded
lookup_str = "tuning"
for key, title in lookup.items():
    if lookup_str in title.lower():
        idx_del.append(key)
    # manually specify keys to exclude 
    elif str(key) in ["0021","0361","0362"]:   # verbal instructions and whole cds that couldnt be converted to midi (too long)
        idx_del.append(key)
    else:
        midi_paths_cleaned.append(Path("data", str(key)+".mid"))

print(f"Loaded {len(midi_paths_cleaned)} valid files, {len(idx_del)} invalid files excluded.")


Loaded 1260 valid files, 55 invalid files excluded.


In [5]:

tokenizer_path = Path("data", "tokenizer", "tokenizer.json")

# load tokenizer if it already exists
if os.path.exists(tokenizer_path):
    tokenizer = Structured(params=tokenizer_path)

else:

    
    # Builds the vocabulary with BPE
    if USE_BPE:
        print(f"Learning BPE...")
        tokenizer.learn_bpe(vocab_size=30000, files_paths=midi_paths_cleaned)
        print(f"Saving tokenizer with BPE to {tokenizer_path}")
        tokenizer.save_params(tokenizer_path)
        print("Finished.")
    # Saves tokenizer without BPE
    else:
        print(f"Saving tokenizer to {tokenizer_path}")
        tokenizer.save_params(tokenizer_path)
        print("Done.")


In [6]:
midi_paths_cleaned[0]

PosixPath('data/0001.mid')

In [7]:
from partitura import load_performance_midi
score = load_performance_midi(midi_paths_cleaned[0])
# Get the duration of the MIDI file in seconds

print("Number of notes in the MIDI file:", len(score.note_array()))
score.note_array()[-1]

Number of notes in the MIDI file: 2188


(219.03645, 3.1432292, 168220, 2414, 71, 28, 0, 0, 'n2187')

In [8]:
midi = tokenizer(midi_paths_cleaned[0])
len(midi[0])

8752

### Split MIDIs into subsequences

In [13]:
# Split MIDIs into smaller chunks for training
MAX_SEQUENCE_LENGTH = 128
dataset_chunks_dir = Path("data", "midi_chunks")

if not os.path.exists(dataset_chunks_dir) or not os.listdir(dataset_chunks_dir):
    midi_paths_chunks = split_midis_for_training(
        files_paths=midi_paths_cleaned,
        tokenizer=tokenizer,
        save_dir=dataset_chunks_dir,
        max_seq_len=MAX_SEQUENCE_LENGTH,
    )
else: 
    midi_paths_chunks = [Path(p) for p in dataset_chunks_dir.iterdir() if p.is_file()][1:] #first object is some hidden file

In [16]:
print(f"Total number of files after splitting into chunks: ",len(os.listdir(dataset_chunks_dir)))
# create train and validation-set 
random.shuffle(midi_paths_chunks)
val_chunks_paths = sorted(midi_paths_chunks[0:len(midi_paths_chunks)//5])
train_chunks_paths = sorted(midi_paths_chunks[len(midi_paths_chunks)//5:])

print(f"Size after splitting into train/val : {len(train_chunks_paths)} / {len(val_chunks_paths)}")

Total number of files after splitting into chunks:  148667
Size after splitting into train/val : 118933 / 29733


### Dataloading and Collator

In [17]:
# Load midi chunks into dataset
dataset_train = DatasetMIDI(
    files_paths=train_chunks_paths,
    max_seq_len=MAX_SEQUENCE_LENGTH,
    tokenizer=tokenizer,
    bos_token_id=tokenizer["BOS_None"],
    eos_token_id=tokenizer["EOS_None"]
)

dataset_val = DatasetMIDI(
    files_paths=val_chunks_paths,
    max_seq_len=MAX_SEQUENCE_LENGTH,
    tokenizer=tokenizer,
    bos_token_id=tokenizer["BOS_None"],
    eos_token_id=tokenizer["EOS_None"]
)
# setup Collator
collator = DataCollator(
    tokenizer["PAD_None"], pad_on_left=True, copy_inputs_as_labels=True , shift_labels=True)

# Set up dataloader
data_loader_train = DataLoader(dataset=dataset_train, batch_size=64, collate_fn=collator,shuffle=True)
data_loader_val = DataLoader(dataset=dataset_val, batch_size=64, collate_fn=collator,shuffle=False)
print(f"Dataloader created.")
print(f"N samples in train/val : {len(data_loader_train)*64} / {len(data_loader_val)*64}")

Dataloader created.
N samples in train/val : 118976 / 29760


In [18]:
# Inspect elements in batch
first_batch = next(iter(data_loader_train))
first_batch


{'input_ids': tensor([[192,  41, 110,  ..., 188,  49, 109],
         [189,  42, 106,  ..., 188,  40, 110],
         [196,  43, 108,  ..., 188,  53, 111],
         ...,
         [190,  46, 109,  ..., 189,  46, 101],
         [192,  52, 110,  ..., 188,  37, 106],
         [192,  17, 102,  ..., 188,  40, 107]]),
 'labels': tensor([[ 41, 110, 124,  ...,  49, 109, 139],
         [ 42, 106, 124,  ...,  40, 110, 124],
         [ 43, 108, 131,  ...,  53, 111, 135],
         ...,
         [ 46, 109, 126,  ...,  46, 101, 125],
         [ 52, 110, 124,  ...,  37, 106, 138],
         [ 17, 102, 128,  ...,  40, 107, 135]]),
 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         ...,
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1]], dtype=torch.int32)}

## Training the model

### Transformer model class

In [12]:
from torch.nn import functional as F
from transformers import GPT2LMHeadModel, AutoConfig, Trainer, TrainingArguments
from transformers.optimization import AdamW 

# Setting the seed
L.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)


Seed set to 42


Device: cpu


In [30]:
# load config and setup untrained model of gpt2
config = AutoConfig.from_pretrained("gpt2",vocab_size = len(tokenizer),
                                    n_ctx=MAX_SEQUENCE_LENGTH, 
                                    bos_token_id = tokenizer["BOS_None"], 
                                    eos_token_id = tokenizer["EOS_None"],
                                    )

model = GPT2LMHeadModel(config)
"Number of Parameters in model:", sum(p.numel() for p in model.parameters()),model.config



('Number of Parameters in model:',
 86083584,
 GPT2Config {
   "_name_or_path": "gpt2",
   "activation_function": "gelu_new",
   "architectures": [
     "GPT2LMHeadModel"
   ],
   "attn_pdrop": 0.1,
   "bos_token_id": 1,
   "embd_pdrop": 0.1,
   "eos_token_id": 2,
   "initializer_range": 0.02,
   "layer_norm_epsilon": 1e-05,
   "model_type": "gpt2",
   "n_ctx": 128,
   "n_embd": 768,
   "n_head": 12,
   "n_inner": null,
   "n_layer": 12,
   "n_positions": 1024,
   "reorder_and_upcast_attn": false,
   "resid_pdrop": 0.1,
   "scale_attn_by_inverse_layer_idx": false,
   "scale_attn_weights": true,
   "summary_activation": null,
   "summary_first_dropout": 0.1,
   "summary_proj_to_labels": true,
   "summary_type": "cls_index",
   "summary_use_proj": true,
   "task_specific_params": {
     "text-generation": {
       "do_sample": true,
       "max_length": 50
     }
   },
   "transformers_version": "4.39.3",
   "use_cache": true,
   "vocab_size": 314
 })

In [35]:
from lightning.pytorch.callbacks import ModelCheckpoint
# set training args
training_args = TrainingArguments(
    output_dir = "./model/gpt-2",
    evaluation_strategy = "epoch",
    auto_find_batch_size=True,
    num_train_epochs=4,
    gradient_accumulation_steps=8,
    weight_decay=0.1,
    lr_scheduler_type="reduce_lr_on_plateau",
    learning_rate=5e-5,
    warmup_steps=20,
    fp16=True, #only on cuda
    logging_steps=10,
    save_total_limit = 1, # saves only most recent checkpoint 
    load_best_model_at_end= True, # + best one 
    save_strategy="epoch",
    save_only_model=True, # set to "False" if you want to resume training at some other point
    report_to="wandb",
    #hub_strategy="end",    # add huggingface hub info
    #hub_token="EnterToken",
    )

# setup trainer
trainer = Trainer(
    model=model, 
    tokenizer= tokenizer, 
    args = training_args,
    data_collator=collator,
    train_dataset=dataset_train,
    eval_dataset=dataset_val,  
)

# Train the model
trainer.train()

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
  1%|          | 18/2096 [04:17<8:14:58, 14.29s/it]
  0%|          | 10/2096 [00:24<1:26:34,  2.49s/it]

{'loss': 4.849, 'grad_norm': 6.272318363189697, 'learning_rate': 0.0005, 'epoch': 0.02}


  1%|          | 20/2096 [00:49<1:26:30,  2.50s/it]

{'loss': 3.1237, 'grad_norm': 1.508876085281372, 'learning_rate': 0.0005, 'epoch': 0.04}


  1%|▏         | 30/2096 [01:14<1:26:11,  2.50s/it]

{'loss': 2.6568, 'grad_norm': 1.3366955518722534, 'learning_rate': 0.0005, 'epoch': 0.06}


  2%|▏         | 40/2096 [01:39<1:25:41,  2.50s/it]

{'loss': 2.616, 'grad_norm': 0.7493968605995178, 'learning_rate': 0.0005, 'epoch': 0.08}


  2%|▏         | 50/2096 [02:04<1:25:09,  2.50s/it]

{'loss': 2.5898, 'grad_norm': 0.7158461213111877, 'learning_rate': 0.0005, 'epoch': 0.1}


  3%|▎         | 60/2096 [02:30<1:25:06,  2.51s/it]

{'loss': 2.5632, 'grad_norm': 1.0997432470321655, 'learning_rate': 0.0005, 'epoch': 0.11}


  3%|▎         | 70/2096 [02:55<1:25:15,  2.52s/it]

{'loss': 2.5357, 'grad_norm': 0.7623980045318604, 'learning_rate': 0.0005, 'epoch': 0.13}


  4%|▍         | 80/2096 [03:20<1:24:39,  2.52s/it]

{'loss': 2.5425, 'grad_norm': 0.7991243600845337, 'learning_rate': 0.0005, 'epoch': 0.15}


  4%|▍         | 90/2096 [03:45<1:23:45,  2.51s/it]

{'loss': 2.5245, 'grad_norm': 0.8218442797660828, 'learning_rate': 0.0005, 'epoch': 0.17}


  5%|▍         | 100/2096 [04:10<1:23:33,  2.51s/it]

{'loss': 2.5169, 'grad_norm': 1.0740547180175781, 'learning_rate': 0.0005, 'epoch': 0.19}


  5%|▌         | 110/2096 [04:35<1:23:14,  2.51s/it]

{'loss': 2.4838, 'grad_norm': 1.0208696126937866, 'learning_rate': 0.0005, 'epoch': 0.21}


  6%|▌         | 120/2096 [05:00<1:22:40,  2.51s/it]

{'loss': 2.4893, 'grad_norm': 1.0650662183761597, 'learning_rate': 0.0005, 'epoch': 0.23}


  6%|▌         | 130/2096 [05:25<1:22:11,  2.51s/it]

{'loss': 2.4502, 'grad_norm': 0.8887924551963806, 'learning_rate': 0.0005, 'epoch': 0.25}


  7%|▋         | 140/2096 [05:51<1:22:33,  2.53s/it]

{'loss': 2.4395, 'grad_norm': 1.0189565420150757, 'learning_rate': 0.0005, 'epoch': 0.27}


  7%|▋         | 150/2096 [06:16<1:19:32,  2.45s/it]

{'loss': 2.4233, 'grad_norm': 0.9571822881698608, 'learning_rate': 0.0005, 'epoch': 0.29}


  8%|▊         | 160/2096 [06:40<1:17:32,  2.40s/it]

{'loss': 2.4032, 'grad_norm': 1.0389875173568726, 'learning_rate': 0.0005, 'epoch': 0.3}


  8%|▊         | 170/2096 [07:04<1:16:54,  2.40s/it]

{'loss': 2.3811, 'grad_norm': 0.7524461150169373, 'learning_rate': 0.0005, 'epoch': 0.32}


  9%|▊         | 180/2096 [07:28<1:16:44,  2.40s/it]

{'loss': 2.3728, 'grad_norm': 0.6924772262573242, 'learning_rate': 0.0005, 'epoch': 0.34}


  9%|▉         | 190/2096 [07:52<1:16:25,  2.41s/it]

{'loss': 2.3625, 'grad_norm': 0.6229585409164429, 'learning_rate': 0.0005, 'epoch': 0.36}


 10%|▉         | 200/2096 [08:16<1:15:43,  2.40s/it]

{'loss': 2.3519, 'grad_norm': 0.7252780199050903, 'learning_rate': 0.0005, 'epoch': 0.38}


 10%|█         | 210/2096 [08:40<1:19:22,  2.53s/it]

{'loss': 2.3492, 'grad_norm': 0.6663147807121277, 'learning_rate': 0.0005, 'epoch': 0.4}


 10%|█         | 220/2096 [09:06<1:18:59,  2.53s/it]

{'loss': 2.3516, 'grad_norm': 0.6458888649940491, 'learning_rate': 0.0005, 'epoch': 0.42}


 11%|█         | 230/2096 [09:31<1:18:19,  2.52s/it]

{'loss': 2.3135, 'grad_norm': 0.7586749792098999, 'learning_rate': 0.0005, 'epoch': 0.44}


 11%|█▏        | 240/2096 [09:56<1:18:55,  2.55s/it]

{'loss': 2.3139, 'grad_norm': 0.7245934009552002, 'learning_rate': 0.0005, 'epoch': 0.46}


 12%|█▏        | 250/2096 [10:22<1:17:01,  2.50s/it]

{'loss': 2.3314, 'grad_norm': 0.657798171043396, 'learning_rate': 0.0005, 'epoch': 0.48}


 12%|█▏        | 260/2096 [10:46<1:15:47,  2.48s/it]

{'loss': 2.313, 'grad_norm': 0.6414426565170288, 'learning_rate': 0.0005, 'epoch': 0.5}


 13%|█▎        | 270/2096 [11:11<1:15:15,  2.47s/it]

{'loss': 2.3021, 'grad_norm': 0.6203900575637817, 'learning_rate': 0.0005, 'epoch': 0.51}


 13%|█▎        | 280/2096 [11:36<1:14:48,  2.47s/it]

{'loss': 2.2902, 'grad_norm': 0.6247804164886475, 'learning_rate': 0.0005, 'epoch': 0.53}


 14%|█▍        | 290/2096 [12:01<1:14:24,  2.47s/it]

{'loss': 2.296, 'grad_norm': 0.5822560787200928, 'learning_rate': 0.0005, 'epoch': 0.55}


 14%|█▍        | 300/2096 [12:25<1:14:02,  2.47s/it]

{'loss': 2.2912, 'grad_norm': 0.687946081161499, 'learning_rate': 0.0005, 'epoch': 0.57}


 15%|█▍        | 310/2096 [12:50<1:13:43,  2.48s/it]

{'loss': 2.2847, 'grad_norm': 0.6708627939224243, 'learning_rate': 0.0005, 'epoch': 0.59}


 15%|█▌        | 320/2096 [13:15<1:13:14,  2.47s/it]

{'loss': 2.2609, 'grad_norm': 0.6638692617416382, 'learning_rate': 0.0005, 'epoch': 0.61}


 16%|█▌        | 330/2096 [13:40<1:12:59,  2.48s/it]

{'loss': 2.2713, 'grad_norm': 0.5498295426368713, 'learning_rate': 0.0005, 'epoch': 0.63}


 16%|█▌        | 340/2096 [14:04<1:12:40,  2.48s/it]

{'loss': 2.245, 'grad_norm': 0.6660364866256714, 'learning_rate': 0.0005, 'epoch': 0.65}


 17%|█▋        | 350/2096 [14:29<1:13:17,  2.52s/it]

{'loss': 2.2651, 'grad_norm': 0.5867024660110474, 'learning_rate': 0.0005, 'epoch': 0.67}


 17%|█▋        | 360/2096 [14:55<1:12:58,  2.52s/it]

{'loss': 2.2416, 'grad_norm': 0.6488157510757446, 'learning_rate': 0.0005, 'epoch': 0.69}


 18%|█▊        | 370/2096 [15:20<1:12:41,  2.53s/it]

{'loss': 2.2295, 'grad_norm': 0.6869769096374512, 'learning_rate': 0.0005, 'epoch': 0.7}


 18%|█▊        | 380/2096 [15:45<1:12:12,  2.52s/it]

{'loss': 2.2234, 'grad_norm': 0.6897150278091431, 'learning_rate': 0.0005, 'epoch': 0.72}


 19%|█▊        | 390/2096 [16:10<1:11:36,  2.52s/it]

{'loss': 2.2225, 'grad_norm': 0.62090003490448, 'learning_rate': 0.0005, 'epoch': 0.74}


 19%|█▉        | 400/2096 [16:36<1:11:33,  2.53s/it]

{'loss': 2.186, 'grad_norm': 0.6215182542800903, 'learning_rate': 0.0005, 'epoch': 0.76}


 20%|█▉        | 410/2096 [17:01<1:10:43,  2.52s/it]

{'loss': 2.1934, 'grad_norm': 0.5960199236869812, 'learning_rate': 0.0005, 'epoch': 0.78}


 20%|██        | 420/2096 [17:26<1:10:32,  2.53s/it]

{'loss': 2.1944, 'grad_norm': 0.8065938949584961, 'learning_rate': 0.0005, 'epoch': 0.8}


 21%|██        | 430/2096 [17:51<1:09:58,  2.52s/it]

{'loss': 2.1898, 'grad_norm': 0.6014075875282288, 'learning_rate': 0.0005, 'epoch': 0.82}


 21%|██        | 440/2096 [18:17<1:09:33,  2.52s/it]

{'loss': 2.1743, 'grad_norm': 0.6542462706565857, 'learning_rate': 0.0005, 'epoch': 0.84}


 21%|██▏       | 450/2096 [18:42<1:09:10,  2.52s/it]

{'loss': 2.1909, 'grad_norm': 0.5119603276252747, 'learning_rate': 0.0005, 'epoch': 0.86}


 22%|██▏       | 460/2096 [19:07<1:08:46,  2.52s/it]

{'loss': 2.1734, 'grad_norm': 0.6942004561424255, 'learning_rate': 0.0005, 'epoch': 0.88}


 22%|██▏       | 470/2096 [19:32<1:08:20,  2.52s/it]

{'loss': 2.1814, 'grad_norm': 0.8493196368217468, 'learning_rate': 0.0005, 'epoch': 0.9}


 23%|██▎       | 480/2096 [19:58<1:07:55,  2.52s/it]

{'loss': 2.1481, 'grad_norm': 0.6971127390861511, 'learning_rate': 0.0005, 'epoch': 0.91}


 23%|██▎       | 490/2096 [20:23<1:07:26,  2.52s/it]

{'loss': 2.153, 'grad_norm': 0.8271818161010742, 'learning_rate': 0.0005, 'epoch': 0.93}


 24%|██▍       | 500/2096 [20:48<1:06:55,  2.52s/it]

{'loss': 2.1452, 'grad_norm': 0.6589524745941162, 'learning_rate': 0.0005, 'epoch': 0.95}


 24%|██▍       | 510/2096 [21:13<1:06:25,  2.51s/it]

{'loss': 2.1267, 'grad_norm': 0.5349586606025696, 'learning_rate': 0.0005, 'epoch': 0.97}


 25%|██▍       | 520/2096 [21:38<1:05:58,  2.51s/it]

{'loss': 2.1528, 'grad_norm': 0.5191571116447449, 'learning_rate': 0.0005, 'epoch': 0.99}


 25%|██▌       | 524/2096 [21:48<1:05:46,  2.51s/it]
 25%|██▌       | 524/2096 [23:45<1:05:46,  2.51s/it]

{'eval_loss': 2.1198012828826904, 'eval_runtime': 114.9075, 'eval_samples_per_second': 73.076, 'eval_steps_per_second': 9.138, 'epoch': 1.0}


 25%|██▌       | 530/2096 [24:00<3:38:08,  8.36s/it] 

{'loss': 2.1578, 'grad_norm': 0.5229832530021667, 'learning_rate': 0.0005, 'epoch': 1.01}


 26%|██▌       | 540/2096 [24:24<1:08:26,  2.64s/it]

{'loss': 2.1252, 'grad_norm': 0.5483594536781311, 'learning_rate': 0.0005, 'epoch': 1.03}


 26%|██▌       | 550/2096 [24:49<1:03:47,  2.48s/it]

{'loss': 2.1181, 'grad_norm': 0.5035076141357422, 'learning_rate': 0.0005, 'epoch': 1.05}


 27%|██▋       | 560/2096 [25:14<1:03:17,  2.47s/it]

{'loss': 2.1109, 'grad_norm': 0.6755514740943909, 'learning_rate': 0.0005, 'epoch': 1.07}


 27%|██▋       | 570/2096 [25:39<1:02:52,  2.47s/it]

{'loss': 2.1139, 'grad_norm': 0.7810893654823303, 'learning_rate': 0.0005, 'epoch': 1.09}


 28%|██▊       | 580/2096 [26:03<1:02:32,  2.48s/it]

{'loss': 2.1037, 'grad_norm': 0.6707762479782104, 'learning_rate': 0.0005, 'epoch': 1.11}


 28%|██▊       | 590/2096 [26:28<1:02:12,  2.48s/it]

{'loss': 2.1273, 'grad_norm': 0.5444919466972351, 'learning_rate': 0.0005, 'epoch': 1.12}


 29%|██▊       | 600/2096 [26:53<1:01:54,  2.48s/it]

{'loss': 2.1158, 'grad_norm': 0.6166770458221436, 'learning_rate': 0.0005, 'epoch': 1.14}


 29%|██▉       | 610/2096 [27:18<1:01:24,  2.48s/it]

{'loss': 2.0871, 'grad_norm': 0.49031081795692444, 'learning_rate': 0.0005, 'epoch': 1.16}


 30%|██▉       | 620/2096 [27:43<1:01:07,  2.48s/it]

{'loss': 2.1224, 'grad_norm': 0.5745970010757446, 'learning_rate': 0.0005, 'epoch': 1.18}


 30%|███       | 630/2096 [28:07<1:00:48,  2.49s/it]

{'loss': 2.1158, 'grad_norm': 0.7179087996482849, 'learning_rate': 0.0005, 'epoch': 1.2}


 31%|███       | 640/2096 [28:32<1:00:21,  2.49s/it]

{'loss': 2.1098, 'grad_norm': 0.6040222644805908, 'learning_rate': 0.0005, 'epoch': 1.22}


 31%|███       | 650/2096 [28:57<59:54,  2.49s/it]  

{'loss': 2.0903, 'grad_norm': 0.6214043498039246, 'learning_rate': 0.0005, 'epoch': 1.24}


 31%|███▏      | 660/2096 [29:22<59:30,  2.49s/it]

{'loss': 2.0963, 'grad_norm': 0.613247275352478, 'learning_rate': 0.0005, 'epoch': 1.26}


 32%|███▏      | 670/2096 [29:47<59:09,  2.49s/it]

{'loss': 2.1105, 'grad_norm': 0.498766154050827, 'learning_rate': 0.0005, 'epoch': 1.28}


 32%|███▏      | 680/2096 [30:12<58:38,  2.48s/it]

{'loss': 2.102, 'grad_norm': 0.6117123365402222, 'learning_rate': 0.0005, 'epoch': 1.3}


 33%|███▎      | 690/2096 [30:37<58:16,  2.49s/it]

{'loss': 2.0912, 'grad_norm': 0.6242431998252869, 'learning_rate': 0.0005, 'epoch': 1.31}


 33%|███▎      | 700/2096 [31:02<57:56,  2.49s/it]

{'loss': 2.102, 'grad_norm': 0.5441591143608093, 'learning_rate': 0.0005, 'epoch': 1.33}


 34%|███▍      | 710/2096 [31:26<57:23,  2.48s/it]

{'loss': 2.076, 'grad_norm': 0.6195278167724609, 'learning_rate': 0.0005, 'epoch': 1.35}


 34%|███▍      | 720/2096 [31:51<56:55,  2.48s/it]

{'loss': 2.0891, 'grad_norm': 0.584503173828125, 'learning_rate': 0.0005, 'epoch': 1.37}


 35%|███▍      | 730/2096 [32:16<56:19,  2.47s/it]

{'loss': 2.0694, 'grad_norm': 0.5284234285354614, 'learning_rate': 0.0005, 'epoch': 1.39}


 35%|███▌      | 740/2096 [32:41<55:58,  2.48s/it]

{'loss': 2.091, 'grad_norm': 0.6110583543777466, 'learning_rate': 0.0005, 'epoch': 1.41}


 36%|███▌      | 750/2096 [33:06<55:30,  2.47s/it]

{'loss': 2.0549, 'grad_norm': 0.4935417175292969, 'learning_rate': 0.0005, 'epoch': 1.43}


 36%|███▋      | 760/2096 [33:30<55:02,  2.47s/it]

{'loss': 2.0857, 'grad_norm': 0.5839532017707825, 'learning_rate': 0.0005, 'epoch': 1.45}


 37%|███▋      | 770/2096 [33:55<54:35,  2.47s/it]

{'loss': 2.0692, 'grad_norm': 0.476393461227417, 'learning_rate': 0.0005, 'epoch': 1.47}


 37%|███▋      | 780/2096 [34:20<54:12,  2.47s/it]

{'loss': 2.0735, 'grad_norm': 0.5042306184768677, 'learning_rate': 0.0005, 'epoch': 1.49}


 38%|███▊      | 790/2096 [34:44<53:45,  2.47s/it]

{'loss': 2.0439, 'grad_norm': 0.6084648370742798, 'learning_rate': 0.0005, 'epoch': 1.51}


 38%|███▊      | 800/2096 [35:09<53:26,  2.47s/it]

{'loss': 2.0591, 'grad_norm': 0.608302891254425, 'learning_rate': 0.0005, 'epoch': 1.52}


 39%|███▊      | 810/2096 [35:34<52:59,  2.47s/it]

{'loss': 2.054, 'grad_norm': 0.5436345338821411, 'learning_rate': 0.0005, 'epoch': 1.54}


 39%|███▉      | 820/2096 [35:59<52:37,  2.47s/it]

{'loss': 2.073, 'grad_norm': 0.5939953327178955, 'learning_rate': 0.0005, 'epoch': 1.56}


 40%|███▉      | 830/2096 [36:23<52:16,  2.48s/it]

{'loss': 2.0522, 'grad_norm': 0.6693093180656433, 'learning_rate': 0.0005, 'epoch': 1.58}


 40%|████      | 840/2096 [36:48<51:53,  2.48s/it]

{'loss': 2.0543, 'grad_norm': 0.5161689519882202, 'learning_rate': 0.0005, 'epoch': 1.6}


 41%|████      | 850/2096 [37:13<51:39,  2.49s/it]

{'loss': 2.0825, 'grad_norm': 0.6182892918586731, 'learning_rate': 0.0005, 'epoch': 1.62}


 41%|████      | 860/2096 [37:38<51:08,  2.48s/it]

{'loss': 2.0487, 'grad_norm': 0.49574872851371765, 'learning_rate': 0.0005, 'epoch': 1.64}


 42%|████▏     | 870/2096 [38:03<50:51,  2.49s/it]

{'loss': 2.0428, 'grad_norm': 0.5037855505943298, 'learning_rate': 0.0005, 'epoch': 1.66}


 42%|████▏     | 880/2096 [38:28<50:21,  2.48s/it]

{'loss': 2.0522, 'grad_norm': 0.5726516246795654, 'learning_rate': 0.0005, 'epoch': 1.68}


 42%|████▏     | 890/2096 [38:52<49:55,  2.48s/it]

{'loss': 2.0546, 'grad_norm': 0.6078126430511475, 'learning_rate': 0.0005, 'epoch': 1.7}


 43%|████▎     | 900/2096 [39:17<49:39,  2.49s/it]

{'loss': 2.0429, 'grad_norm': 0.5060900449752808, 'learning_rate': 0.0005, 'epoch': 1.71}


 43%|████▎     | 910/2096 [39:42<49:09,  2.49s/it]

{'loss': 2.0467, 'grad_norm': 0.607549786567688, 'learning_rate': 0.0005, 'epoch': 1.73}


 44%|████▍     | 920/2096 [40:07<48:49,  2.49s/it]

{'loss': 2.0472, 'grad_norm': 0.5276082754135132, 'learning_rate': 0.0005, 'epoch': 1.75}


 44%|████▍     | 930/2096 [40:32<48:17,  2.49s/it]

{'loss': 2.0218, 'grad_norm': 0.4966421127319336, 'learning_rate': 0.0005, 'epoch': 1.77}


 45%|████▍     | 940/2096 [40:57<47:57,  2.49s/it]

{'loss': 2.0238, 'grad_norm': 0.5458455085754395, 'learning_rate': 0.0005, 'epoch': 1.79}


 45%|████▌     | 950/2096 [41:22<47:30,  2.49s/it]

{'loss': 2.0492, 'grad_norm': 0.570103645324707, 'learning_rate': 0.0005, 'epoch': 1.81}


 46%|████▌     | 960/2096 [41:47<47:01,  2.48s/it]

{'loss': 2.0505, 'grad_norm': 0.4486362636089325, 'learning_rate': 0.0005, 'epoch': 1.83}


 46%|████▋     | 970/2096 [42:11<46:37,  2.48s/it]

{'loss': 2.0513, 'grad_norm': 0.5724099278450012, 'learning_rate': 0.0005, 'epoch': 1.85}


 47%|████▋     | 980/2096 [42:36<46:10,  2.48s/it]

{'loss': 2.0204, 'grad_norm': 0.5262743830680847, 'learning_rate': 0.0005, 'epoch': 1.87}


 47%|████▋     | 990/2096 [43:01<45:45,  2.48s/it]

{'loss': 2.0442, 'grad_norm': 0.6029306054115295, 'learning_rate': 0.0005, 'epoch': 1.89}


 48%|████▊     | 1000/2096 [43:26<45:17,  2.48s/it]

{'loss': 2.0353, 'grad_norm': 0.5252447128295898, 'learning_rate': 0.0005, 'epoch': 1.91}


 48%|████▊     | 1010/2096 [43:51<44:47,  2.47s/it]

{'loss': 2.0368, 'grad_norm': 0.6718998551368713, 'learning_rate': 0.0005, 'epoch': 1.92}


 49%|████▊     | 1020/2096 [44:15<44:19,  2.47s/it]

{'loss': 2.0153, 'grad_norm': 0.5707575678825378, 'learning_rate': 0.0005, 'epoch': 1.94}


 49%|████▉     | 1030/2096 [44:40<43:54,  2.47s/it]

{'loss': 2.0087, 'grad_norm': 0.4540582001209259, 'learning_rate': 0.0005, 'epoch': 1.96}


 50%|████▉     | 1040/2096 [45:05<43:27,  2.47s/it]

{'loss': 2.0433, 'grad_norm': 0.5452066659927368, 'learning_rate': 0.0005, 'epoch': 1.98}


 50%|█████     | 1049/2096 [45:27<43:04,  2.47s/it]
 50%|█████     | 1049/2096 [47:23<43:04,  2.47s/it]

{'eval_loss': 2.006798028945923, 'eval_runtime': 114.4183, 'eval_samples_per_second': 73.389, 'eval_steps_per_second': 9.177, 'epoch': 2.0}


 50%|█████     | 1050/2096 [47:25<10:49:42, 37.27s/it]

{'loss': 2.0161, 'grad_norm': 0.4466921389102936, 'learning_rate': 0.0005, 'epoch': 2.0}


 51%|█████     | 1060/2096 [47:50<59:39,  3.46s/it]   

{'loss': 2.0212, 'grad_norm': 0.5367615222930908, 'learning_rate': 0.0005, 'epoch': 2.02}


 51%|█████     | 1070/2096 [48:15<42:49,  2.50s/it]

{'loss': 2.0202, 'grad_norm': 0.5714686512947083, 'learning_rate': 0.0005, 'epoch': 2.04}


 52%|█████▏    | 1080/2096 [48:40<42:00,  2.48s/it]

{'loss': 2.0012, 'grad_norm': 0.4649885892868042, 'learning_rate': 0.0005, 'epoch': 2.06}


 52%|█████▏    | 1090/2096 [49:05<41:33,  2.48s/it]

{'loss': 2.0253, 'grad_norm': 0.48176759481430054, 'learning_rate': 0.0005, 'epoch': 2.08}


 52%|█████▏    | 1100/2096 [49:29<41:15,  2.49s/it]

{'loss': 2.0084, 'grad_norm': 0.5923572182655334, 'learning_rate': 0.0005, 'epoch': 2.1}


 53%|█████▎    | 1110/2096 [49:54<40:49,  2.48s/it]

{'loss': 2.0206, 'grad_norm': 0.5229465365409851, 'learning_rate': 0.0005, 'epoch': 2.11}


 53%|█████▎    | 1120/2096 [50:19<40:32,  2.49s/it]

{'loss': 2.0121, 'grad_norm': 0.4495607912540436, 'learning_rate': 0.0005, 'epoch': 2.13}


 54%|█████▍    | 1130/2096 [50:44<40:00,  2.48s/it]

{'loss': 2.0169, 'grad_norm': 0.535701334476471, 'learning_rate': 0.0005, 'epoch': 2.15}


 54%|█████▍    | 1140/2096 [51:09<39:37,  2.49s/it]

{'loss': 1.9954, 'grad_norm': 0.5144317746162415, 'learning_rate': 0.0005, 'epoch': 2.17}


 55%|█████▍    | 1150/2096 [51:34<39:16,  2.49s/it]

{'loss': 2.0138, 'grad_norm': 0.5236411690711975, 'learning_rate': 0.0005, 'epoch': 2.19}


 55%|█████▌    | 1160/2096 [51:59<38:49,  2.49s/it]

{'loss': 1.9938, 'grad_norm': 0.7902648448944092, 'learning_rate': 0.0005, 'epoch': 2.21}


 56%|█████▌    | 1170/2096 [52:23<38:20,  2.48s/it]

{'loss': 1.9967, 'grad_norm': 0.4428461492061615, 'learning_rate': 0.0005, 'epoch': 2.23}


 56%|█████▋    | 1180/2096 [52:48<37:54,  2.48s/it]

{'loss': 2.0043, 'grad_norm': 0.5211853981018066, 'learning_rate': 0.0005, 'epoch': 2.25}


 57%|█████▋    | 1190/2096 [53:13<37:30,  2.48s/it]

{'loss': 2.0019, 'grad_norm': 0.4783218502998352, 'learning_rate': 0.0005, 'epoch': 2.27}


 57%|█████▋    | 1200/2096 [53:38<37:10,  2.49s/it]

{'loss': 1.9989, 'grad_norm': 0.5837000012397766, 'learning_rate': 0.0005, 'epoch': 2.29}


 58%|█████▊    | 1210/2096 [54:03<36:46,  2.49s/it]

{'loss': 1.9931, 'grad_norm': 0.49474453926086426, 'learning_rate': 0.0005, 'epoch': 2.31}


 58%|█████▊    | 1220/2096 [54:28<36:13,  2.48s/it]

{'loss': 1.9976, 'grad_norm': 0.7354146838188171, 'learning_rate': 0.0005, 'epoch': 2.32}


 59%|█████▊    | 1230/2096 [54:53<35:48,  2.48s/it]

{'loss': 2.0182, 'grad_norm': 0.4627850651741028, 'learning_rate': 0.0005, 'epoch': 2.34}


 59%|█████▉    | 1240/2096 [55:17<35:21,  2.48s/it]

{'loss': 1.9911, 'grad_norm': 0.6549704670906067, 'learning_rate': 0.0005, 'epoch': 2.36}


 60%|█████▉    | 1250/2096 [55:42<34:56,  2.48s/it]

{'loss': 1.9972, 'grad_norm': 0.4672211706638336, 'learning_rate': 0.0005, 'epoch': 2.38}


 60%|██████    | 1260/2096 [56:07<34:30,  2.48s/it]

{'loss': 1.9988, 'grad_norm': 0.5100688338279724, 'learning_rate': 0.0005, 'epoch': 2.4}


 61%|██████    | 1270/2096 [56:32<34:05,  2.48s/it]

{'loss': 1.9976, 'grad_norm': 0.5386558771133423, 'learning_rate': 0.0005, 'epoch': 2.42}


 61%|██████    | 1280/2096 [56:57<33:38,  2.47s/it]

{'loss': 2.0179, 'grad_norm': 0.4740334153175354, 'learning_rate': 0.0005, 'epoch': 2.44}


 62%|██████▏   | 1290/2096 [57:21<33:12,  2.47s/it]

{'loss': 2.0202, 'grad_norm': 0.46268486976623535, 'learning_rate': 0.0005, 'epoch': 2.46}


 62%|██████▏   | 1300/2096 [57:46<32:48,  2.47s/it]

{'loss': 1.9994, 'grad_norm': 0.5077358484268188, 'learning_rate': 0.0005, 'epoch': 2.48}


 62%|██████▎   | 1310/2096 [58:11<32:19,  2.47s/it]

{'loss': 1.992, 'grad_norm': 0.5200123190879822, 'learning_rate': 0.0005, 'epoch': 2.5}


 63%|██████▎   | 1320/2096 [58:35<31:56,  2.47s/it]

{'loss': 2.0054, 'grad_norm': 0.4618810713291168, 'learning_rate': 0.0005, 'epoch': 2.51}


 63%|██████▎   | 1330/2096 [59:00<31:31,  2.47s/it]

{'loss': 1.9813, 'grad_norm': 0.4602905809879303, 'learning_rate': 0.0005, 'epoch': 2.53}


 64%|██████▍   | 1340/2096 [59:25<31:07,  2.47s/it]

{'loss': 2.0056, 'grad_norm': 0.48245421051979065, 'learning_rate': 0.0005, 'epoch': 2.55}


 64%|██████▍   | 1350/2096 [59:50<30:46,  2.48s/it]

{'loss': 1.9855, 'grad_norm': 0.5071650743484497, 'learning_rate': 0.0005, 'epoch': 2.57}


 65%|██████▍   | 1360/2096 [1:00:14<30:20,  2.47s/it]

{'loss': 1.9812, 'grad_norm': 0.5310693979263306, 'learning_rate': 0.0005, 'epoch': 2.59}


 65%|██████▌   | 1370/2096 [1:00:39<29:57,  2.48s/it]

{'loss': 1.992, 'grad_norm': 0.6211186647415161, 'learning_rate': 0.0005, 'epoch': 2.61}


 66%|██████▌   | 1380/2096 [1:01:04<29:37,  2.48s/it]

{'loss': 1.991, 'grad_norm': 0.570671796798706, 'learning_rate': 0.0005, 'epoch': 2.63}


 66%|██████▋   | 1390/2096 [1:01:29<29:10,  2.48s/it]

{'loss': 1.9877, 'grad_norm': 0.49815991520881653, 'learning_rate': 0.0005, 'epoch': 2.65}


 67%|██████▋   | 1400/2096 [1:01:54<29:12,  2.52s/it]

{'loss': 1.9849, 'grad_norm': 0.5839947462081909, 'learning_rate': 0.0005, 'epoch': 2.67}


 67%|██████▋   | 1410/2096 [1:02:19<28:30,  2.49s/it]

{'loss': 1.9914, 'grad_norm': 0.4895860552787781, 'learning_rate': 0.0005, 'epoch': 2.69}


 68%|██████▊   | 1420/2096 [1:02:43<27:26,  2.44s/it]

{'loss': 2.001, 'grad_norm': 0.5143311619758606, 'learning_rate': 0.0005, 'epoch': 2.71}


 68%|██████▊   | 1430/2096 [1:03:07<26:44,  2.41s/it]

{'loss': 1.9694, 'grad_norm': 0.48150232434272766, 'learning_rate': 0.0005, 'epoch': 2.72}


 69%|██████▊   | 1440/2096 [1:03:31<26:18,  2.41s/it]

{'loss': 1.9825, 'grad_norm': 0.5410529375076294, 'learning_rate': 0.0005, 'epoch': 2.74}


 69%|██████▉   | 1450/2096 [1:03:56<27:14,  2.53s/it]

{'loss': 1.9755, 'grad_norm': 0.5147268176078796, 'learning_rate': 0.0005, 'epoch': 2.76}


 70%|██████▉   | 1460/2096 [1:04:21<25:31,  2.41s/it]

{'loss': 1.9896, 'grad_norm': 0.5437952876091003, 'learning_rate': 0.0005, 'epoch': 2.78}


 70%|███████   | 1470/2096 [1:04:45<25:04,  2.40s/it]

{'loss': 1.9748, 'grad_norm': 0.5027430057525635, 'learning_rate': 0.0005, 'epoch': 2.8}


 71%|███████   | 1480/2096 [1:05:09<24:38,  2.40s/it]

{'loss': 1.9746, 'grad_norm': 0.47641435265541077, 'learning_rate': 0.0005, 'epoch': 2.82}


 71%|███████   | 1490/2096 [1:05:33<24:22,  2.41s/it]

{'loss': 1.9795, 'grad_norm': 0.5555567145347595, 'learning_rate': 0.0005, 'epoch': 2.84}


 72%|███████▏  | 1500/2096 [1:05:57<23:50,  2.40s/it]

{'loss': 1.981, 'grad_norm': 0.5788444876670837, 'learning_rate': 0.0005, 'epoch': 2.86}


 72%|███████▏  | 1510/2096 [1:06:22<24:27,  2.50s/it]

{'loss': 1.9924, 'grad_norm': 0.5602027177810669, 'learning_rate': 0.0005, 'epoch': 2.88}


 73%|███████▎  | 1520/2096 [1:06:46<23:09,  2.41s/it]

{'loss': 1.969, 'grad_norm': 0.5817355513572693, 'learning_rate': 0.0005, 'epoch': 2.9}


 73%|███████▎  | 1530/2096 [1:07:10<22:39,  2.40s/it]

{'loss': 1.9654, 'grad_norm': 0.6433917284011841, 'learning_rate': 0.0005, 'epoch': 2.91}


 73%|███████▎  | 1540/2096 [1:07:34<22:14,  2.40s/it]

{'loss': 1.9796, 'grad_norm': 0.6077769994735718, 'learning_rate': 0.0005, 'epoch': 2.93}


 74%|███████▍  | 1550/2096 [1:07:58<21:48,  2.40s/it]

{'loss': 1.9605, 'grad_norm': 0.4295545816421509, 'learning_rate': 0.0005, 'epoch': 2.95}


 74%|███████▍  | 1560/2096 [1:08:22<21:29,  2.41s/it]

{'loss': 1.9664, 'grad_norm': 0.448434978723526, 'learning_rate': 0.0005, 'epoch': 2.97}


 75%|███████▍  | 1570/2096 [1:08:47<21:59,  2.51s/it]

{'loss': 1.9818, 'grad_norm': 0.5367178320884705, 'learning_rate': 0.0005, 'epoch': 2.99}


 75%|███████▌  | 1574/2096 [1:08:57<21:54,  2.52s/it]
 75%|███████▌  | 1574/2096 [1:10:55<21:54,  2.52s/it]

{'eval_loss': 1.9547278881072998, 'eval_runtime': 116.5218, 'eval_samples_per_second': 72.064, 'eval_steps_per_second': 9.011, 'epoch': 3.0}


 75%|███████▌  | 1580/2096 [1:11:10<1:12:58,  8.48s/it]

{'loss': 1.9772, 'grad_norm': 0.473624587059021, 'learning_rate': 0.0005, 'epoch': 3.01}


 76%|███████▌  | 1590/2096 [1:11:36<22:38,  2.69s/it]  

{'loss': 1.9473, 'grad_norm': 0.5243654251098633, 'learning_rate': 0.0005, 'epoch': 3.03}


 76%|███████▋  | 1600/2096 [1:12:01<20:47,  2.52s/it]

{'loss': 1.9767, 'grad_norm': 0.5592474937438965, 'learning_rate': 0.0005, 'epoch': 3.05}


 77%|███████▋  | 1610/2096 [1:12:25<19:31,  2.41s/it]

{'loss': 1.9558, 'grad_norm': 0.4319092631340027, 'learning_rate': 0.0005, 'epoch': 3.07}


 77%|███████▋  | 1620/2096 [1:12:49<19:20,  2.44s/it]

{'loss': 1.987, 'grad_norm': 0.4389091730117798, 'learning_rate': 0.0005, 'epoch': 3.09}


 78%|███████▊  | 1630/2096 [1:13:14<19:22,  2.49s/it]

{'loss': 1.9579, 'grad_norm': 0.6627464294433594, 'learning_rate': 0.0005, 'epoch': 3.11}


 78%|███████▊  | 1640/2096 [1:13:39<18:34,  2.44s/it]

{'loss': 1.9654, 'grad_norm': 0.5108437538146973, 'learning_rate': 0.0005, 'epoch': 3.12}


 79%|███████▊  | 1650/2096 [1:14:03<18:12,  2.45s/it]

{'loss': 1.9684, 'grad_norm': 0.5572566390037537, 'learning_rate': 0.0005, 'epoch': 3.14}


 79%|███████▉  | 1660/2096 [1:14:28<17:52,  2.46s/it]

{'loss': 1.9588, 'grad_norm': 0.5366126298904419, 'learning_rate': 0.0005, 'epoch': 3.16}


 80%|███████▉  | 1670/2096 [1:14:52<17:19,  2.44s/it]

{'loss': 1.9644, 'grad_norm': 0.42596936225891113, 'learning_rate': 0.0005, 'epoch': 3.18}


 80%|████████  | 1680/2096 [1:15:16<17:06,  2.47s/it]

{'loss': 1.9694, 'grad_norm': 0.5594634413719177, 'learning_rate': 0.0005, 'epoch': 3.2}


 81%|████████  | 1690/2096 [1:15:42<17:03,  2.52s/it]

{'loss': 1.9706, 'grad_norm': 0.5379141569137573, 'learning_rate': 0.0005, 'epoch': 3.22}


 81%|████████  | 1700/2096 [1:16:07<16:45,  2.54s/it]

{'loss': 1.9573, 'grad_norm': 0.4584985673427582, 'learning_rate': 0.0005, 'epoch': 3.24}


 82%|████████▏ | 1710/2096 [1:16:32<16:23,  2.55s/it]

{'loss': 1.9633, 'grad_norm': 0.4019224941730499, 'learning_rate': 0.0005, 'epoch': 3.26}


 82%|████████▏ | 1720/2096 [1:16:58<15:57,  2.55s/it]

{'loss': 1.9431, 'grad_norm': 0.6185517311096191, 'learning_rate': 0.0005, 'epoch': 3.28}


 83%|████████▎ | 1730/2096 [1:17:23<15:28,  2.54s/it]

{'loss': 1.9583, 'grad_norm': 0.48843204975128174, 'learning_rate': 0.0005, 'epoch': 3.3}


 83%|████████▎ | 1740/2096 [1:17:48<15:05,  2.54s/it]

{'loss': 1.954, 'grad_norm': 0.46324339509010315, 'learning_rate': 0.0005, 'epoch': 3.32}


 83%|████████▎ | 1750/2096 [1:18:14<14:43,  2.55s/it]

{'loss': 1.9668, 'grad_norm': 0.5270750522613525, 'learning_rate': 0.0005, 'epoch': 3.33}


 84%|████████▍ | 1760/2096 [1:18:39<14:20,  2.56s/it]

{'loss': 1.9535, 'grad_norm': 0.49303510785102844, 'learning_rate': 0.0005, 'epoch': 3.35}


 84%|████████▍ | 1770/2096 [1:19:05<14:08,  2.60s/it]

{'loss': 1.9534, 'grad_norm': 0.5045322179794312, 'learning_rate': 0.0005, 'epoch': 3.37}


 85%|████████▍ | 1780/2096 [1:19:31<13:39,  2.59s/it]

{'loss': 1.9432, 'grad_norm': 0.47860482335090637, 'learning_rate': 0.0005, 'epoch': 3.39}


 85%|████████▌ | 1790/2096 [1:19:57<13:05,  2.57s/it]

{'loss': 1.9555, 'grad_norm': 0.5104812979698181, 'learning_rate': 0.0005, 'epoch': 3.41}


 86%|████████▌ | 1800/2096 [1:20:23<12:40,  2.57s/it]

{'loss': 1.9493, 'grad_norm': 0.4695645272731781, 'learning_rate': 0.0005, 'epoch': 3.43}


 86%|████████▋ | 1810/2096 [1:20:48<12:05,  2.54s/it]

{'loss': 1.9483, 'grad_norm': 0.5520764589309692, 'learning_rate': 0.0005, 'epoch': 3.45}


 87%|████████▋ | 1820/2096 [1:21:14<11:40,  2.54s/it]

{'loss': 1.961, 'grad_norm': 0.5209382772445679, 'learning_rate': 0.0005, 'epoch': 3.47}


 87%|████████▋ | 1830/2096 [1:21:39<11:22,  2.57s/it]

{'loss': 1.926, 'grad_norm': 0.5416978597640991, 'learning_rate': 0.0005, 'epoch': 3.49}


 88%|████████▊ | 1840/2096 [1:22:05<10:52,  2.55s/it]

{'loss': 1.9374, 'grad_norm': 0.5232625603675842, 'learning_rate': 0.0005, 'epoch': 3.51}


 88%|████████▊ | 1850/2096 [1:22:30<10:27,  2.55s/it]

{'loss': 1.9393, 'grad_norm': 0.44206520915031433, 'learning_rate': 0.0005, 'epoch': 3.52}


 89%|████████▊ | 1860/2096 [1:22:56<10:03,  2.56s/it]

{'loss': 1.9416, 'grad_norm': 0.4870043396949768, 'learning_rate': 0.0005, 'epoch': 3.54}


 89%|████████▉ | 1870/2096 [1:23:22<09:41,  2.57s/it]

{'loss': 1.9525, 'grad_norm': 0.47735607624053955, 'learning_rate': 0.0005, 'epoch': 3.56}


 90%|████████▉ | 1880/2096 [1:23:48<09:16,  2.58s/it]

{'loss': 1.9535, 'grad_norm': 0.4531016945838928, 'learning_rate': 0.0005, 'epoch': 3.58}


 90%|█████████ | 1890/2096 [1:24:13<08:43,  2.54s/it]

{'loss': 1.9426, 'grad_norm': 0.5032697916030884, 'learning_rate': 0.0005, 'epoch': 3.6}


 91%|█████████ | 1900/2096 [1:24:39<08:31,  2.61s/it]

{'loss': 1.9632, 'grad_norm': 0.46632465720176697, 'learning_rate': 0.0005, 'epoch': 3.62}


 91%|█████████ | 1910/2096 [1:25:05<08:06,  2.61s/it]

{'loss': 1.9288, 'grad_norm': 0.4804592430591583, 'learning_rate': 0.0005, 'epoch': 3.64}


 92%|█████████▏| 1920/2096 [1:25:30<07:26,  2.54s/it]

{'loss': 1.9304, 'grad_norm': 0.48542261123657227, 'learning_rate': 0.0005, 'epoch': 3.66}


 92%|█████████▏| 1930/2096 [1:25:56<07:03,  2.55s/it]

{'loss': 1.9367, 'grad_norm': 0.4735771119594574, 'learning_rate': 0.0005, 'epoch': 3.68}


 93%|█████████▎| 1940/2096 [1:26:21<06:38,  2.55s/it]

{'loss': 1.945, 'grad_norm': 0.4760046601295471, 'learning_rate': 0.0005, 'epoch': 3.7}


 93%|█████████▎| 1950/2096 [1:26:47<06:11,  2.54s/it]

{'loss': 1.9262, 'grad_norm': 0.5434844493865967, 'learning_rate': 0.0005, 'epoch': 3.72}


 94%|█████████▎| 1960/2096 [1:27:12<05:46,  2.55s/it]

{'loss': 1.9447, 'grad_norm': 0.543496310710907, 'learning_rate': 0.0005, 'epoch': 3.73}


 94%|█████████▍| 1970/2096 [1:27:38<05:19,  2.54s/it]

{'loss': 1.938, 'grad_norm': 0.4517475962638855, 'learning_rate': 0.0005, 'epoch': 3.75}


 94%|█████████▍| 1980/2096 [1:28:03<04:53,  2.53s/it]

{'loss': 1.9461, 'grad_norm': 0.5462824702262878, 'learning_rate': 0.0005, 'epoch': 3.77}


 95%|█████████▍| 1990/2096 [1:28:29<04:30,  2.55s/it]

{'loss': 1.9289, 'grad_norm': 0.4141314625740051, 'learning_rate': 0.0005, 'epoch': 3.79}


 95%|█████████▌| 2000/2096 [1:28:54<04:04,  2.55s/it]

{'loss': 1.9441, 'grad_norm': 0.5382769703865051, 'learning_rate': 0.0005, 'epoch': 3.81}


 96%|█████████▌| 2010/2096 [1:29:20<03:39,  2.55s/it]

{'loss': 1.9449, 'grad_norm': 0.4685162305831909, 'learning_rate': 0.0005, 'epoch': 3.83}


 96%|█████████▋| 2020/2096 [1:29:45<03:13,  2.55s/it]

{'loss': 1.9263, 'grad_norm': 0.52872633934021, 'learning_rate': 0.0005, 'epoch': 3.85}


 97%|█████████▋| 2030/2096 [1:30:10<02:47,  2.53s/it]

{'loss': 1.9554, 'grad_norm': 0.4658665359020233, 'learning_rate': 0.0005, 'epoch': 3.87}


 97%|█████████▋| 2040/2096 [1:30:36<02:22,  2.55s/it]

{'loss': 1.9292, 'grad_norm': 0.5020055770874023, 'learning_rate': 0.0005, 'epoch': 3.89}


 98%|█████████▊| 2050/2096 [1:31:01<01:56,  2.54s/it]

{'loss': 1.9341, 'grad_norm': 0.4749789237976074, 'learning_rate': 0.0005, 'epoch': 3.91}


 98%|█████████▊| 2060/2096 [1:31:26<01:31,  2.54s/it]

{'loss': 1.9633, 'grad_norm': 0.6083621978759766, 'learning_rate': 0.0005, 'epoch': 3.92}


 99%|█████████▉| 2070/2096 [1:31:52<01:06,  2.54s/it]

{'loss': 1.9414, 'grad_norm': 0.5347632169723511, 'learning_rate': 0.0005, 'epoch': 3.94}


 99%|█████████▉| 2080/2096 [1:32:18<00:40,  2.55s/it]

{'loss': 1.9287, 'grad_norm': 0.5459960103034973, 'learning_rate': 0.0005, 'epoch': 3.96}


100%|█████████▉| 2090/2096 [1:32:43<00:15,  2.56s/it]

{'loss': 1.9342, 'grad_norm': 0.4895651638507843, 'learning_rate': 0.0005, 'epoch': 3.98}


100%|██████████| 2096/2096 [1:32:59<00:00,  2.60s/it]
100%|██████████| 2096/2096 [1:34:58<00:00,  2.60s/it]

{'eval_loss': 1.9172981977462769, 'eval_runtime': 118.857, 'eval_samples_per_second': 70.648, 'eval_steps_per_second': 8.834, 'epoch': 3.99}


There were missing keys in the checkpoint model loaded: ['lm_head.weight'].
100%|██████████| 2096/2096 [1:34:59<00:00,  2.72s/it]

{'train_runtime': 5699.9543, 'train_samples_per_second': 23.574, 'train_steps_per_second': 0.368, 'train_loss': 2.100743389766635, 'epoch': 3.99}





TrainOutput(global_step=2096, training_loss=2.100743389766635, metrics={'train_runtime': 5699.9543, 'train_samples_per_second': 23.574, 'train_steps_per_second': 0.368, 'train_loss': 2.100743389766635, 'epoch': 3.99})

## Inference

In [None]:
from transformers import pipeline, GenerationConfig

model_local = GPT2LMHeadModel.from_pretrained("./model/gpt-2/checkpoint-2096/")

model.eval()

#what to give as input
input = tokenizer["BOS_None"]

generation_config = GenerationConfig(
    max_length = 100, # input + output
    max_new_tokens = 100, # output only - overwrites max_length 
    # could also set min tokens
    # force_eos_token = [tokenizer['EOS_None']]
)

sequence = model.generate(generation_config)
sequence

In [None]:
# Use DAW

import partitura.score as score
from partitura import save_score_midi
# use perform part
part = score.Part(id="example")

note = score.Note(step="C", octave="3") 

part.add(note,0.,1.)
print(part.pretty())
midi_score = score.Score(partlist=[part])

f_path = "./data/example.mid"
_ = save_score_midi(midi_score, f_path)

input = tokenizer(Path(f_path))


## Back to MIDI

In [21]:
import fluidsynth
#!pip install pretty_midi jupyterlab
import pretty_midi 
from IPython.display import Audio
sequence = [[1, 226, 43, 110, 124, 188, 48, 110, 124, 188, 38, 110, 124, 188, 53, 111, 124, 188, 57, 110]]
#print(sequence.tolist())
_ = tokenizer.tokens_to_midi(tokens=sequence)
print(_.note_num())
fpath = os.path.join("data","example","example.mid")
_.dump_midi(fpath)
midi_data = pretty_midi.PrettyMIDI(midi_file=fpath)
audio_data = midi_data.fluidsynth(fs=44100, sf2_path=os.path.join("data","soundfont","Roland_SC-55.sf2"))
print(len(audio_data))
for i in audio_data:
    if i != 0:
        print(i)
with open(fpath[:-3]+"wav", 'wb') as f:
    f.write(audio_data)
Audio(audio_data,rate=44100)



4
168132
0.000375234521575985
0.00075046904315197
0.000375234521575985
0.001125703564727955
0.00075046904315197
0.00150093808630394
0.001876172607879925
0.00300187617260788
0.004127579737335835
0.004127579737335835
0.004878048780487805
0.004878048780487805
0.00525328330206379
0.005628517823639775
0.00675422138836773
0.007129455909943715
0.007879924953095686
0.00825515947467167
0.00825515947467167
0.00975609756097561
0.01125703564727955
0.01350844277673546
0.0150093808630394
0.01575984990619137
0.015384615384615385
0.015384615384615385
0.0150093808630394
0.0150093808630394
0.0150093808630394
0.01575984990619137
0.01651031894934334
0.01801125703564728
0.02026266416510319
0.02326454033771107
0.02551594746716698
0.02551594746716698
0.022138836772983114
0.016885553470919325
0.012382739212007506
0.01275797373358349
0.01951219512195122
0.028893058161350845
0.03527204502814259
0.03677298311444653
0.03452157598499062
0.03189493433395872
0.0300187617260788
0.028893058161350845
0.0240150093808630