In [None]:
# Optional environment setup: every install command below is commented out.
# Uncomment the lines you need and run this cell once in a fresh environment.
# Note that only deepspeed and lightning are version-pinned; the three
# Hugging Face packages are installed straight from their git repositories.
# %pip install -U torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117
# %pip install miditok
# %pip install deepspeed==0.9.2
# %pip install lightning==2.0.2
# %pip install torchtoolkit
# %pip install git+https://github.com/huggingface/transformers
# %pip install git+https://github.com/huggingface/accelerate
# %pip install git+https://github.com/huggingface/evaluate
# %pip install tqdm
# %pip install wandb
# %pip install gdown

In [None]:
import deepspeed
import numpy as np
import random
import sys
import os
import torch
import gc
import json
import datetime
import lightning.pytorch as pl
from pathlib import Path
from torchtoolkit.data import create_subsets
from collections import namedtuple
from sympy import randprime

# Start from a clean slate: collect unreachable Python objects and release
# any CUDA memory that PyTorch is still caching.
gc.collect()
torch.cuda.empty_cache()

# Draw a random seed in [1000, 10000] and hand it to Lightning's seeding helper.
seed = random.randint(1000, 10000)
pl.seed_everything(seed)

np.set_printoptions(linewidth=200, suppress=True, precision=4)

# Float mode exported through the environment below.
precision = 'bf16'

# RWKV_* settings; presumably read by the project's `model` module when it is
# imported right after this block -- TODO confirm against model.py.
os.environ.update({
    'RWKV_JIT_ON': '0',
    'RWKV_FLOAT_MODE': precision,
    'RWKV_T_MAX': '1024',
})

# Prefer the first CUDA device, fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Switch into the RWKV project directory so its local modules can be imported.
os.chdir('/home/nico/dev/projects/ai/midigpt/rwkv/')

from model import RWKV

In [None]:
# Make modules in the current working directory importable. The membership
# check keeps the cell idempotent: re-running it no longer appends another
# duplicate './' entry to sys.path each time.
if './' not in sys.path:
    sys.path.append('./')

# Last expression of the cell: show the working directory as a sanity check.
os.getcwd()

In [None]:
from miditok import REMIPlus, MMM
from miditok.constants import ADDITIONAL_TOKENS

# Project name and the directory holding its tokenized data; the BPE variant
# lives in a `bpe` sub-directory.
PROJ_NAME = 'drums'
IS_BPE = True
TOKENS_PATH = f"/home/nico/data/ai/models/midi/{PROJ_NAME}{'/bpe' if IS_BPE else ''}"

# Output directory for this project (created relative to the working directory).
Path(f'./out/{PROJ_NAME}').mkdir(parents=True, exist_ok=True)

# Plain integers; the previous `(24)` spelling looked like a tuple but was not one.
BINS_VELOCITY = 24
BINS_TEMPO = 24

# Work on a copy of the library defaults. Assigning into ADDITIONAL_TOKENS
# itself would silently change miditok's module-level constant for every other
# tokenizer created in this kernel (and accumulate across cell re-runs).
additional_tokens = dict(ADDITIONAL_TOKENS)
additional_tokens['Chord'] = True
additional_tokens['TimeSignature'] = True
additional_tokens['Program'] = True
additional_tokens['nb_tempos'] = BINS_TEMPO
tokenizer = MMM(
    additional_tokens=additional_tokens,
    params=f'{TOKENS_PATH}/token_params.cfg',
    nb_velocities=BINS_VELOCITY
)

# Base vocabulary size and the BPE target size (125% of the base vocabulary).
ORIG_VOCAB_SIZE = len(tokenizer.vocab)
BPE_VOCAB_SIZE = int(ORIG_VOCAB_SIZE * 1.25)

# Displayed as the cell output for a quick comparison of the three sizes.
(ORIG_VOCAB_SIZE, BPE_VOCAB_SIZE, len(tokenizer))

In [None]:
from midi_dataset import MIDIDataset, DataCollatorGen

# Maximum sequence length handed to the dataset.
CTX_LEN = 2048

# Collect every tokenized JSON file under TOKENS_PATH and shuffle them in place.
midi_jsons = [*Path(TOKENS_PATH).glob('*.json')]
random.shuffle(midi_jsons)

dataset_kwargs = {
    'files_paths': midi_jsons,
    'min_seq_len': 16,
    'max_seq_len': CTX_LEN,
    'no_labels': False,
}
midi_dataset = MIDIDataset(**dataset_kwargs)

# Split with a single 0.3 ratio; presumably 30% goes to validation and the
# remainder to training -- TODO confirm against torchtoolkit's create_subsets.
subset_train, subset_valid = create_subsets(midi_dataset, [0.3])

In [None]:
# --- Tunable hyper-parameters -------------------------------------------------
BATCHES = 6
N_EMBED = 512
N_LAYER = 12
MAGIC_PRIME = randprime(10**6, 10**10)
SUBSET_NPY = f'out/{PROJ_NAME}/subset_train.npy'
EPOCHS = 6
EPOCH_STEPS = 1000
LR_RATE = 8e-4
LR_DECAY = 5e-6

# Full argument set consumed by the RWKV model, dataset and training callback.
# Key order is kept as-is because it determines the namedtuple field order below.
params = dict(
    accelerator='gpu',
    adam_eps=1e-8,
    betas=(0.9, 0.99),
    ctx_len=int(os.environ['RWKV_T_MAX']),
    data_file=SUBSET_NPY,
    data_type='numpy',
    devices=1,
    dim_att=N_EMBED,
    dim_ffn=N_EMBED * 4,
    ds_bucket_mb=200,
    eight_bits=False,
    epoch_begin=0,
    epoch_count=EPOCHS,
    epoch_save=1,
    epoch_steps=EPOCH_STEPS,
    grad_cp=0,  # model.py:530
    gradient_clip_val=1.0,
    head_qk=int(N_EMBED * 2),
    layerwise_lr=1,
    log_every_n_steps=10,
    lr_final=LR_RATE / 80,
    lr_init=LR_RATE,
    magic_prime=MAGIC_PRIME,
    micro_bsz=BATCHES,
    my_exit=99999999,
    my_pile_edecay=LR_DECAY,
    my_pile_stage=0,
    my_pos_emb=0,
    my_qa_mask=0,
    my_random_steps=0,
    my_timestamp=datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S"),
    n_embd=N_EMBED,
    n_layer=N_LAYER,
    padding_idx=0,
    pre_ffn=0,
    proj_dir=f'out/{PROJ_NAME}',
    real_bsz=BATCHES,
    strategy='ddp_find_unused_parameters_false',
    tiny_att_dim=-1,  # int(N_EMBED/4), model.py:406
    tiny_att_layer=-1,  # model.py:406
    vocab_size=BPE_VOCAB_SIZE if IS_BPE else ORIG_VOCAB_SIZE,
    wandb='',
    warmup_steps=10,
)

# Expose the settings through attribute access (params_obj.n_embd, ...).
RWKVParams = namedtuple('RWKVParams', list(params))
params_obj = RWKVParams(**params)

In [None]:
# Build the RWKV model from the hyper-parameter tuple and move it to `device`
# in one step (`.to()` hands back the module itself), then echo the module so
# the cell still displays it.
model_base = RWKV(params_obj).to(device)
model_base

In [None]:
# Flatten the `input_ids` of every sample (training subset followed by the
# validation subset) into one long token list.
ids = [
    token
    for sample in subset_train + subset_valid
    for token in sample['input_ids'].numpy()
]

# Persist the flat token stream; this is the file the training dataset reads.
np.save(SUBSET_NPY, ids, allow_pickle=False)

In [None]:
# Number of distinct token ids present in the saved training stream.
len(np.unique(np.load(SUBSET_NPY)))

In [None]:
from dataset import MyDataset
from torch.utils.data import DataLoader

# Dataset built from the shared hyper-parameter tuple.
train_data = MyDataset(params_obj)

# Loader settings: fixed order, pinned host memory, incomplete final batch dropped.
loader_kwargs = {
    'shuffle': False,
    'pin_memory': True,
    'batch_size': params_obj.micro_bsz,
    'num_workers': 4,
    'persistent_workers': False,
    'drop_last': True,
}
data_loader = DataLoader(train_data, **loader_kwargs)

In [None]:
import trainer
from importlib import reload

# Pick up edits made to trainer.py without restarting the kernel.
reload(trainer)

# Derive the Lightning precision from the float mode exported for the RWKV
# code (RWKV_FLOAT_MODE, 'bf16' in this notebook). The trainer previously
# hard-coded '16', i.e. fp16 mixed precision, which disagreed with the bf16
# mode the model code was configured for.
_float_mode = os.environ['RWKV_FLOAT_MODE']
if '32' in _float_mode:
    # covers both 'fp32' and 'tf32'
    trainer_precision = '32-true'
elif _float_mode == 'fp16':
    trainer_precision = '16-mixed'
else:
    trainer_precision = 'bf16-mixed'

trainer_params = {
    'gradient_clip_val': 1.0,
    'log_every_n_steps': 100,
    'devices': 'auto',
    # Total optimisation steps across all epochs.
    'max_steps': EPOCH_STEPS*EPOCHS,
    'accelerator': 'gpu',
    'strategy': 'auto',
    'enable_checkpointing': True,
    'precision': trainer_precision,
    'callbacks': [trainer.train_callback(params_obj)],
}
trainer_pl = pl.Trainer(**trainer_params)

In [None]:
# os.environ['RANK'] = '0'
# os.environ['WORLD_SIZE'] = '4'
# os.environ['MASTER_ADDR'] = 'desktop'
# os.environ['MASTER_PORT'] = '7777'

# from torch.distributed import launch

# torch.distributed.init_process_group()

# Only a DeepSpeed strategy exposes a `config` with ZeRO bucket sizes.
# The check has to look at the configured strategy *value*: testing
# `"deepspeed" in trainer_params` only inspected the dict's keys and could
# never be true. The bucket size itself lives on `params_obj`, not on the
# `trainer_params` dict (which has no attribute access at all).
if "deepspeed" in str(trainer_params.get("strategy", "")):
    bucket_bytes = params_obj.ds_bucket_mb * 1000 * 1000
    zero_cfg = trainer_pl.strategy.config["zero_optimization"]
    zero_cfg["allgather_bucket_size"] = bucket_bytes
    zero_cfg["reduce_bucket_size"] = bucket_bytes

# Run training on the prepared loader.
trainer_pl.fit(model_base, data_loader)
