In [1]:
#@title NVIDIA GPU check
!nvidia-smi

Mon Jan  8 16:49:36 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   44C    P8               9W /  70W |      0MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [2]:
#@title clone the model repo
!git clone --depth 1 https://github.com/asigalov61/Tiny-Music-Transformer

Cloning into 'Tiny-Music-Transformer'...
remote: Enumerating objects: 29, done.[K
remote: Counting objects: 100% (29/29), done.[K
remote: Compressing objects: 100% (28/28), done.[K
remote: Total 29 (delta 0), reused 0 (delta 0), pack-reused 0[K
Receiving objects: 100% (29/29), 1.01 MiB | 9.43 MiB/s, done.


In [22]:
# Install dependencies

!pip install huggingface_hub torch einops torch-summary tqdm matplotlib
!apt install fluidsynth

# Import modules
import os
import random
import torch
from torch.utils.data import DataLoader, Dataset
from torchsummary import summary
import matplotlib.pyplot as plt
import tqdm
from huggingface_hub import hf_hub_download

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
fluidsynth is already the newest version (2.2.5-1).
0 upgraded, 0 newly installed, 0 to remove and 24 not upgraded.


In [4]:
%cd /content/Tiny-Music-Transformer
from x_transformer_1_23_2 import TransformerWrapper, AutoregressiveWrapper, Decoder
import TMIDIX
#from midi_to_colab_audio import midi_to_colab_audio

/content/Tiny-Music-Transformer


In [5]:
#@title Load TMIDIX MIDI Processor

print('Loading TMIDIX MIDI Processor...')

def TMIDIX_MIDI_Processor(midi_file):

    melody_chords = []

    try:
        fn = os.path.basename(midi_file)

        # Filtering out GIANT4 MIDIs
        file_size = os.path.getsize(midi_file)

        if file_size <= 1000000:

          score = TMIDIX.midi2single_track_ms_score(open(midi_file, 'rb').read(),
                                          recalculate_channels=False)

          enhanced_score = TMIDIX.advanced_score_processor(score,
                                                          return_score_analysis=False,
                                                          return_enhanced_score_notes=True)[0]

          if len(enhanced_score) > 0:
              if min([e[1] for e in enhanced_score]) >= 0 and min([e[2] for e in enhanced_score]) >= 0:

                  if len(enhanced_score) > 0:

                      for e in enhanced_score:
                        e[1] = int(e[1] / 16)
                        e[2] = int(e[2] / 16)

                      # Sorting by patch, pitch, then by start-time

                      enhanced_score.sort(key=lambda x: x[6])
                      enhanced_score.sort(key=lambda x: x[4], reverse=True)
                      enhanced_score.sort(key=lambda x: x[1])

                      pe = enhanced_score[0]

                      notes = []

                      for e in enhanced_score:

                        time = max(0, min(127, (e[1] - pe[1])))
                        dur = max(0, min(127, e[2]))
                        cha = max(0, min(15, e[3]))
                        ptc = max(1, min(127, e[4]))

                        notes.append([time, dur, cha, ptc])

                        pe = e

                      return notes

    except Exception as e:
      print('Error!')
      print('Exception', e)
      return None

print('Done!')

Loading TMIDIX MIDI Processor...
Done!


In [6]:
!wget 'https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip'
!unzip 'maestro-v3.0.0-midi.zip' -d '/content/Tiny-Music-Transformer/maestro_dataset_v3'

--2024-01-08 16:58:32--  https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 108.177.119.207, 108.177.127.207, 172.217.218.207, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|108.177.119.207|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 58416533 (56M) [application/octet-stream]
Saving to: ‘maestro-v3.0.0-midi.zip’


2024-01-08 16:58:34 (28.6 MB/s) - ‘maestro-v3.0.0-midi.zip’ saved [58416533/58416533]

Archive:  maestro-v3.0.0-midi.zip
  inflating: /content/Tiny-Music-Transformer/maestro_dataset_v3/maestro-v3.0.0/2004/MIDI-Unprocessed_XP_08_R1_2004_01-02_ORIG_MID--AUDIO_08_R1_2004_01_Track01_wav.midi  
  inflating: /content/Tiny-Music-Transformer/maestro_dataset_v3/maestro-v3.0.0/2004/MIDI-Unprocessed_XP_09_R1_2004_05_ORIG_MID--AUDIO_09_R1_2004_06_Track06_wav.midi  
  inflating: /content/Tiny-Music-Transformer/maestro_dataset_v3/maestro-v3.0.

In [11]:
# Load and process MIDI files
dataset_addr = "/content/Tiny-Music-Transformer/maestro_dataset_v3/maestro-v3.0.0"
filez = [os.path.join(dirpath, file) for dirpath, _, filenames in os.walk(dataset_addr) for file in filenames]


In [12]:
len(filez)

1280

In [16]:
%cd /content/

if not os.path.exists('/content/INTS'):
    os.makedirs('/content/INTS')

import random
from joblib import Parallel, delayed, parallel_config

melody_chords_f = []

print('Processing MIDI files. Please wait...')
print('=' * 70)

for i in tqdm.tqdm(range(0, len(filez), 16)):

  with parallel_config(backend='threading', n_jobs=16, verbose = 0):

    output = Parallel()(delayed(TMIDIX_MIDI_Processor)(f) for f in filez[i:i+16])

    for o in output:

        if o is not None:
            melody_chords_f.append(o)

/content
Processing MIDI files. Please wait...


  0%|          | 0/80 [00:00<?, ?it/s]

Error!
Exception list index out of range
Error!
Exception list index out of rangeError!
Exception
Error!
Exception list index out of range
 list index out of range


midi2opus: midi starts with b'Plea' instead of 'MThd'
midi2opus: midi starts with b'Attr' instead of 'MThd'
midi2opus: midi starts with b'{"ca' instead of 'MThd'
midi2opus: midi starts with b'cano' instead of 'MThd'
100%|██████████| 80/80 [15:57<00:00, 11.96s/it]


In [17]:
# @title Convert processed MIDIs to INTs
SEQ_LEN = 8192
PAD_IDX = 643

train_data = []

for m in tqdm.tqdm(melody_chords_f):

    if m[0][2] == 0:
        cha = 0
    if m[0][2] == 3:
        cha = 1

    dat = [642, 512+cha, m[0][3]+514, 0]

    for mm in m:

        if mm[2] == 0:
            cha = 0
        if mm[2] == 3:
            cha = 1

        if mm[0] != 0:
           dat.extend([mm[0], mm[1]+128, ((cha * 128) + mm[3])+256])
        else:
           dat.extend([mm[1]+128, ((cha * 128) + mm[3])+256])

    dat = dat[:SEQ_LEN+1]
    dat += [PAD_IDX] * (SEQ_LEN+1 - len(dat))

    train_data.append(dat)

random.shuffle(train_data)

print('Done!')

100%|██████████| 1276/1276 [00:04<00:00, 282.07it/s]

Done!





In [18]:
# @title Save and Load INTs
TMIDIX.Tegridy_Any_Pickle_File_Writer(train_data, '/content/Training_INTs_maestro')

train_data = TMIDIX.Tegridy_Any_Pickle_File_Reader('/content/Training_INTs_maestro')
print('Done!')

Tegridy Pickle File Writer
Creating new Dataset file...
Dataset was saved as: /content/Training_INTs_maestro.pickle
Task complete. Enjoy! :)
Tegridy Pickle File Loader
Loading the pickle file. Please wait...
Done!


In [23]:
#@title Load Tiny Music Transformer Pre-Trained Model

select_model_to_load = "139M-32L-Very-Fast-Tiny"

model_precision = "bfloat16"

plot_tokens_embeddings = "None"

full_path_to_models_dir = "/content/Tiny-Music-Transformer/Model/"

print('Loading Tiny Music Transformer Training Data...')
print('Please wait...')

training_data_file = 'Tiny_Music_Transformer_Train_Data_Mono_Melodies_Piano_Violin_MIDI_Dataset.pickle'

if os.path.isfile(full_path_to_models_dir+training_data_file):
  print('Training data already exists...')

else:
  hf_hub_download(repo_id='asigalov61/Tiny-Music-Transformer',
                  filename=training_data_file,
                  local_dir='/content/Tiny-Music-Transformer/Model',
                  local_dir_use_symlinks=False)

train_data = TMIDIX.Tegridy_Any_Pickle_File_Reader('/content/Tiny-Music-Transformer/Model/Tiny_Music_Transformer_Train_Data_Mono_Melodies_Piano_Violin_MIDI_Dataset')
print('Done!')
print('Loading Tiny Music Transformer', select_model_to_load,'Pre-Trained Model...')

if select_model_to_load == '139M-32L-Very-Fast-Tiny':

  model_checkpoint_file_name = 'Tiny_Music_Transformer_Tiny_Trained_Model_10737_steps_0.4039_loss_0.8729_acc.pth'
  model_path = full_path_to_models_dir+model_checkpoint_file_name
  if os.path.isfile(model_path):
    print('Model already exists...')

  else:
    hf_hub_download(repo_id='asigalov61/Tiny-Music-Transformer',
                    filename=model_checkpoint_file_name,
                    local_dir='/content/Tiny-Music-Transformer/Model',
                    local_dir_use_symlinks=False)



print('=' * 70)
print('Instantiating model...')

torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda'

if model_precision == 'bfloat16' and torch.cuda.is_bf16_supported():
  dtype = 'bfloat16'
else:
  dtype = 'float16'

if model_precision == 'float16':
  dtype = 'float16'

ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)

SEQ_LEN = 8192

model = TransformerWrapper(
    num_tokens = 644,
    max_seq_len = SEQ_LEN,
    attn_layers = Decoder(dim = 512, depth = 32, heads = 16, attn_flash = True)
)

model = AutoregressiveWrapper(model, ignore_index=643)

model.cuda()
model.load_state_dict(torch.load(model_path))

model.eval()

print('Done!')

print('Model will use', dtype, 'precision...')

# Model stats
print('Model summary...')
summary(model)

Loading Tiny Music Transformer Training Data...
Please wait...
Training data already exists...
Tegridy Pickle File Loader
Loading the pickle file. Please wait...
Done!
Loading Tiny Music Transformer 139M-32L-Very-Fast-Tiny Pre-Trained Model...
Model already exists...
Instantiating model...
Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda
Done!
Model will use float16 precision...
Model summary...
Layer (type:depth-idx)                        Param #
├─TransformerWrapper: 1-1                     --
|    └─TokenEmbedding: 2-1                    --
|    |    └─Embedding: 3-1                    329,728
|    └─AbsolutePositionalEmbedding: 2-2       --
|    |    └─Embedding: 3-2                    4,194,304
|    └─Identity: 2-3                          --
|    └─Dropout: 2-4                           --
|    └─Identity: 2-5                          --
|    └─Decoder: 2-6                           --
|    |    └─ModuleList: 3-3                   134,365,1

Layer (type:depth-idx)                        Param #
├─TransformerWrapper: 1-1                     --
|    └─TokenEmbedding: 2-1                    --
|    |    └─Embedding: 3-1                    329,728
|    └─AbsolutePositionalEmbedding: 2-2       --
|    |    └─Embedding: 3-2                    4,194,304
|    └─Identity: 2-3                          --
|    └─Dropout: 2-4                           --
|    └─Identity: 2-5                          --
|    └─Decoder: 2-6                           --
|    |    └─ModuleList: 3-3                   134,365,184
|    |    └─LayerNorm: 3-4                    1,024
|    └─Linear: 2-7                            330,372
Total params: 139,220,612
Trainable params: 139,220,612
Non-trainable params: 0

In [24]:
#@title Model Parameters
SEQ_LEN = 8192 # Models seq len
PAD_IDX = 643 # Models pad index

BATCH_SIZE = 4
NUM_EPOCHS = 25
GRADIENT_ACCUMULATE_EVERY = 4

LEARNING_RATE = 2e-4

VALIDATE_EVERY  = 10
SAVE_EVERY = 25
GENERATE_EVERY  = 100
PRINT_STATS_EVERY = 20

GENERATE_LENGTH = 32

# helpers

def cycle(loader):
    while True:
        for data in loader:
            yield data

In [25]:
class MusicDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):

        full_seq = torch.Tensor(self.data[index][:self.seq_len+1]).long()

        return full_seq.cuda()

    def __len__(self):
        return (len(self.data) // BATCH_SIZE) * BATCH_SIZE

In [26]:
#@title precision/optimizer/scaler

dtype = torch.float16
ctx = torch.amp.autocast(device_type='cuda', dtype=dtype)
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
scaler = torch.cuda.amp.GradScaler(enabled=True)

In [27]:
import gc
torch.cuda.empty_cache()
gc.collect()

53

In [None]:
train_losses = []
val_losses = []

train_accs = []
val_accs = []

nsteps = 0

for ep in range(NUM_EPOCHS):

  print('Epoch #', ep)

  random.shuffle(train_data)

  train_dataset = MusicDataset(train_data, SEQ_LEN)
  val_dataset   = MusicDataset(train_data, SEQ_LEN)

  train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
  val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

  NUM_BATCHES = len(train_data) // BATCH_SIZE // GRADIENT_ACCUMULATE_EVERY

  for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='Training'):

      model.train()

      for __ in range(GRADIENT_ACCUMULATE_EVERY):
          with ctx:
              loss, acc = model(next(train_loader))
          loss = loss / GRADIENT_ACCUMULATE_EVERY
          scaler.scale(loss).backward(torch.ones(loss.shape).cuda())

      if i % PRINT_STATS_EVERY == 0:
          print(f'Training loss: {loss.mean().item() * GRADIENT_ACCUMULATE_EVERY}')
          print(f'Training acc: {acc.mean().item()}')

      train_losses.append(loss.mean().item() * GRADIENT_ACCUMULATE_EVERY)
      train_accs.append(acc.mean().item())

      scaler.unscale_(optim)
      torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
      scaler.step(optim)
      scaler.update()
      optim.zero_grad(set_to_none=True)

      nsteps += 1

      if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
          with ctx:
            val_loss, val_acc = model(next(val_loader))

            print(f'Validation loss: {val_loss.mean().item()}')
            print(f'Validation acc: {val_acc.mean().item()}')

            val_losses.append(val_loss.mean().item())
            val_accs.append(val_acc.mean().item())

            print('Plotting training loss graph...')

            tr_loss_list = train_losses
            plt.plot([i for i in range(len(tr_loss_list))] ,tr_loss_list, 'b')
            plt.show()
            plt.close()
            print('Done!')

            print('Plotting validation loss graph...')
            tr_loss_list = val_losses
            plt.plot([i for i in range(len(tr_loss_list))] ,tr_loss_list, 'b')
            plt.show()
            plt.close()
            print('Done!')

      if i % GENERATE_EVERY == 0:
        model.eval()

        inp = random.choice(val_dataset)[:-1]

        print(inp)

        with ctx:

            sample = model.generate(inp[None, ...], GENERATE_LENGTH)

        print(sample)

      if i % SAVE_EVERY == 0:

          print('Saving model progress. Please wait...')
          print('model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss_' + str(round(float(train_accs[-1]), 4)) + '_acc.pth')

          fname = '/content/model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss_' + str(round(float(train_accs[-1]), 4)) + '_acc.pth'

          torch.save(model.state_dict(), fname)

          data = [train_losses, train_accs, val_losses, val_accs]

          TMIDIX.Tegridy_Any_Pickle_File_Writer(data, '/content/losses_accs')

          print('Done!')


print('Saving model progress. Please wait...')
print('model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss_' + str(round(float(train_accs[-1]), 4)) + '_acc.pth')

fname = '/content/model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss_' + str(round(float(train_accs[-1]), 4)) + '_acc.pth'

torch.save(model.state_dict(), fname)

print('Done!')

data = [train_losses, train_accs, val_losses, val_accs]

TMIDIX.Tegridy_Any_Pickle_File_Writer(data, '/content/losses_accuracies')