# Bachelor Graduation Project Part 2: Music Generation with X-Transformers


Powered by tegridy-tools: https://github.com/asigalov61/tegridy-tools


# (SETUP ENVIRONMENT)

In [1]:
# @title NVIDIA GPU Check
!nvidia-smi

Tue May 14 09:24:20 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off | 00000000:00:04.0 Off |                    0 |
| N/A   31C    P0              40W / 400W |      2MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [2]:
#@title Install all dependencies (run only once per session)
!git clone https://github.com/asigalov61/tegridy-tools
!pip install einops
!pip install torch-summary

Cloning into 'tegridy-tools'...
remote: Enumerating objects: 3644, done.[K
remote: Counting objects: 100% (402/402), done.[K
remote: Compressing objects: 100% (165/165), done.[K
remote: Total 3644 (delta 330), reused 275 (delta 235), pack-reused 3242[K
Receiving objects: 100% (3644/3644), 149.84 MiB | 12.05 MiB/s, done.
Resolving deltas: 100% (2387/2387), done.
Collecting einops
  Downloading einops-0.8.0-py3-none-any.whl (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.8.0
Collecting torch-summary
  Downloading torch_summary-1.4.5-py3-none-any.whl (16 kB)
Installing collected packages: torch-summary
Successfully installed torch-summary-1.4.5


In [3]:
#@title Import all needed modules

print('Loading modules...')

import os
import pickle
import secrets
import statistics
import tqdm
import math
import copy
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import matplotlib.pyplot as plt

from torchsummary import summary
from sklearn import metrics

%cd /content/tegridy-tools/tegridy-tools/

import TMIDIX

%cd /content/tegridy-tools/tegridy-tools/X-Transformer

from x_transformer_1_23_2 import *

torch.set_float32_matmul_precision('high')
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn

%cd /content/

if not os.path.exists('/content/Dataset'):
    os.makedirs('/content/Dataset')

if not os.path.exists('/content/INTS'):
    os.makedirs('/content/INTS')

import random

from joblib import Parallel, delayed, parallel_config

print('PyTorch version:', torch.__version__)
print('Done')

Loading modules...
/content/tegridy-tools/tegridy-tools
/content/tegridy-tools/tegridy-tools/X-Transformer
/content
PyTorch version: 2.2.1+cu121
Done


# (DOWNLOAD AND UNZIP MIDI DATASET)

In [4]:
# @title Download and unzip Maestro v2 MIDI Dataset
%cd /content/Dataset
!wget https://github.com/asigalov61/Tegridy-MIDI-Dataset/raw/master/Misc/POP909-Piano-Violin-CC-BY-NC-SA.zip
!unzip POP909-Piano-Violin-CC-BY-NC-SA.zip
!rm POP909-Piano-Violin-CC-BY-NC-SA.zip
%cd /content/

/content/Dataset
--2024-05-14 09:24:54--  https://github.com/asigalov61/Tegridy-MIDI-Dataset/raw/master/Misc/POP909-Piano-Violin-CC-BY-NC-SA.zip
Resolving github.com (github.com)... 20.205.243.166
Connecting to github.com (github.com)|20.205.243.166|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/asigalov61/Tegridy-MIDI-Dataset/master/Misc/POP909-Piano-Violin-CC-BY-NC-SA.zip [following]
--2024-05-14 09:24:55--  https://raw.githubusercontent.com/asigalov61/Tegridy-MIDI-Dataset/master/Misc/POP909-Piano-Violin-CC-BY-NC-SA.zip
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 16922099 (16M) [application/zip]
Saving to: ‘POP909-Piano-Violin-CC-BY-NC-SA.zip’


2024-05-14 09:24:57 (337 MB/s) - ‘POP909-P

# (MIDI PROCESSOR)

In [5]:
#@title Load TMIDIX MIDI Processor

print('=' * 70)
print('Loading TMIDIX MIDI Processor...')
print('=' * 70)

def TMIDIX_MIDI_Processor(midi_file):

    melody_chords = []

    try:

        fn = os.path.basename(midi_file)

        #=======================================================
        # START PROCESSING

        raw_score = TMIDIX.midi2single_track_ms_score(midi_file)

        escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]

        escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes, timings_divider=16)

        all_scores = []

        for ta in range(0, 6, 2):
          for pa in range(-1, 2):

            escore_notes_aug = copy.deepcopy(escore_notes)

            for e in escore_notes_aug:
              e[1] += ta
              e[2] += ta
              e[4] += pa

            violin_dscore = [d for d in escore_notes_aug if d[6] == 40]
            violin_mono_melody_score = [n[0] for n in TMIDIX.chordify_score([1000, violin_dscore])]

            piano_dscore = [d for d in escore_notes_aug if d[6] == 0]

            violin_piano_score_notes = sorted(violin_mono_melody_score+piano_dscore, key=lambda x: x[1])

            violin_piano_dscore_notes = [d[1:] for d in TMIDIX.delta_score_notes(violin_piano_score_notes,
                                                                                even_timings=True,
                                                                                compress_timings=True)]

            all_scores.append(violin_piano_dscore_notes)

        return all_scores

    except Exception as e:
      print('Error!')
      print('Exception', e)
      return None

print('Done!')
print('=' * 70)

Loading TMIDIX MIDI Processor...
Done!


# (FILE LIST)

In [6]:
#@title Save file list
###########

print('=' * 70)
print('Loading MIDI files...')
print('This may take a while on a large dataset in particular.')

dataset_addr = "/content/Dataset"

# os.chdir(dataset_addr)
filez = list()
for (dirpath, dirnames, filenames) in os.walk(dataset_addr):
    filez += [os.path.join(dirpath, file) for file in filenames]
print('=' * 70)

if not filez:
    print('Could not find any MIDI files. Please check Dataset dir...')
    print('=' * 70)

else:
  print('Randomizing file list...')
  random.shuffle(filez)
  print('Done!')
  print('=' * 70)
  print('Total files:', len(filez))
  print('=' * 70)

Loading MIDI files...
This may take a while on a large dataset in particular.
Randomizing file list...
Done!
Total files: 2085


# (PROCESS MIDIs)

In [7]:
#@title Process MIDIs with TMIDIX MIDI processor

print('=' * 70)
print('TMIDIX MIDI Processor')
print('=' * 70)
print('Starting up...')
print('=' * 70)

###########

melody_chords_f = []

print('Processing MIDI files. Please wait...')
print('=' * 70)

for i in tqdm.tqdm(range(0, len(filez), 16)):

  with parallel_config(backend='threading', n_jobs=16, verbose = 0):

    output = Parallel()(delayed(TMIDIX_MIDI_Processor)(f) for f in filez[i:i+16])

    for o in output:

        if o is not None:
            melody_chords_f.append(o)

print('Done!')
print('=' * 70)

TMIDIX MIDI Processor
Starting up...
Processing MIDI files. Please wait...


 65%|██████▍   | 85/131 [06:18<03:30,  4.58s/it]

Error!
Exception 'NoneType' object is not iterable


 94%|█████████▍| 123/131 [09:05<00:33,  4.13s/it]

Error!
Exception 'NoneType' object is not iterable


100%|██████████| 131/131 [09:37<00:00,  4.41s/it]

Done!





# (SAVE/LOAD PROCESSED MIDIs)

In [8]:
# @title Save processed MIDIs
TMIDIX.Tegridy_Any_Pickle_File_Writer(melody_chords_f, '/content/Processed_MIDIs')

Tegridy Pickle File Writer
Creating new Dataset file...
Dataset was saved as: /content/Processed_MIDIs.pickle
Task complete. Enjoy! :)


In [9]:
# @title Load processed MIDIs
melody_chords_f = TMIDIX.Tegridy_Any_Pickle_File_Reader('/content/Processed_MIDIs')
print('Done!')

Tegridy Pickle File Loader
Loading the pickle file. Please wait...
Done!


# (TEST PROCESSED MIDIs)

In [10]:
#@title Test Processed MIDIs

train_data1 = random.choice(melody_chords_f[0])

#train_data1 = max(melody_chords_f, key = len)

print('Sample data:', train_data1[:15])

out = train_data1

patches = [0] * 16
patches[3] = 40

if len(out) != 0:

    song = out
    song_f = []

    time = 0
    dur = 0
    vel = 90
    pitch = 0
    channel = 0

    for s in song:

        time += s[0]
        dur = s[1]
        channel = s[2]
        pitch = s[3]
        vel = s[4]

        song_f.append(['note', time, dur, channel, pitch, vel ])

detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,
                                                          output_signature = 'Tiny Music Transformer',
                                                          output_file_name = '/content/Tiny-Music-Transformer-Composition',
                                                          track_name='Project Los Angeles',
                                                          list_of_MIDI_patches=patches,
                                                          timings_multiplier=32
                                                          )

print('Done!')

Sample data: [[0, 10, 3, 85, 108, 40], [0, 22, 0, 65, 70, 0], [0, 20, 0, 61, 67, 0], [0, 32, 0, 49, 73, 0], [12, 7, 3, 73, 111, 40], [0, 19, 0, 56, 48, 0], [6, 5, 3, 75, 105, 40], [7, 10, 3, 77, 105, 40], [0, 11, 0, 65, 90, 0], [0, 12, 0, 61, 73, 0], [12, 10, 3, 82, 100, 40], [13, 44, 3, 80, 108, 40], [0, 17, 0, 60, 23, 0], [0, 17, 0, 56, 49, 0], [0, 37, 0, 41, 44, 0]]
Converting to MIDI. Please stand-by...
Done! Enjoy! :)
Done!


# (PREP INTs)

In [11]:
# @title Convert processed MIDIs to INTs
SEQ_LEN = 8192
PAD_IDX = 835

print('=' * 70)

train_data = []

for m in tqdm.tqdm(melody_chords_f):
  for mm in m:

    cscore = TMIDIX.chordify_score(mm)

    dat = [834]

    first_chord = True

    for chord in cscore:

      tones_chord = sorted(set([c[3] % 12 for c in chord]))

      try:
          chord_token = TMIDIX.ALL_CHORDS_SORTED.index(tones_chord) + 512
      except:
          checked_tones_chord = TMIDIX.check_and_fix_tones_chord(tones_chord)
          chord_token = TMIDIX.ALL_CHORDS_SORTED.index(checked_tones_chord) + 512

      dat.append(chord_token)

      if first_chord:
        dat.append(0)
        first_chord = False

      for c in chord:

            if c[2] == 0:
                cha = 0
            if c[2] == 3:
                cha = 1

            if c[0] != 0:
              dat.extend([c[0], c[1]+128, ((cha * 128) + c[3])+256])
            else:
              dat.extend([c[1]+128, ((cha * 128) + c[3])+256])

    dat = dat[:SEQ_LEN+1]
    dat += [PAD_IDX] * (SEQ_LEN+1 - len(dat))

    train_data.append(dat)

# Total dict size 644

random.shuffle(train_data)

print('Done!')
print('=' * 70)
print(len(train_data), max(train_data, key=len) == min(train_data, key=len))
print('=' * 70)
print(len(max(train_data, key=len)), len(min(train_data, key=len)))
print('=' * 70)
print(train_data[0][:15])
print('=' * 70)



100%|██████████| 2083/2083 [01:52<00:00, 18.53it/s]

Done!
18747 True
8193 8193
[834, 546, 0, 139, 320, 199, 316, 199, 313, 199, 309, 199, 301, 779, 19]





# (SAVE/LOAD INTs)

In [12]:
# @title Save INTs
TMIDIX.Tegridy_Any_Pickle_File_Writer(train_data, '/content/Training_INTs')

Tegridy Pickle File Writer
Creating new Dataset file...
Dataset was saved as: /content/Training_INTs.pickle
Task complete. Enjoy! :)


In [13]:
# @title Load INTs
train_data = TMIDIX.Tegridy_Any_Pickle_File_Reader('/content/Training_INTs')
print('Done!')

Tegridy Pickle File Loader
Loading the pickle file. Please wait...
Done!


# (PREP MODEL)

In [14]:
# @title Setup and init the model

# constants

SEQ_LEN = 8192 # Models seq len
PAD_IDX = 835 # Models pad index

BATCH_SIZE = 4
NUM_EPOCHS = 200
GRADIENT_ACCUMULATE_EVERY = 4

LEARNING_RATE = 1e-4

VALIDATE_EVERY  = 100
SAVE_EVERY = 500
GENERATE_EVERY  = 100
PRINT_STATS_EVERY = 20

GENERATE_LENGTH = 32

# helpers

def cycle(loader):
    while True:
        for data in loader:
            yield data

# instantiate the model
model = TransformerWrapper(
    num_tokens = PAD_IDX+1,
    max_seq_len = SEQ_LEN,
    attn_layers = Decoder(
        dim = 512,
        depth = 8,
        heads = 8
    )
).cuda()
'''
model = TransformerWrapper(
    num_tokens = PAD_IDX+1,
    max_seq_len = SEQ_LEN,
    attn_layers = Decoder(dim = 1024,
                          depth = 8,
                          heads = 8,
                          attn_flash = True,
                          use_rmsnorm = True, # set to true to use for all layers)
                          rotary_pos_emb = True  # turns on rotary positional embedding
    ))
'''
model = AutoregressiveWrapper(model, ignore_index=PAD_IDX)


print('Done!')

summary(model)

# Dataloader

class MusicDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):

        full_seq = torch.Tensor(self.data[index][:self.seq_len+1]).long()

        return full_seq.cuda()

    def __len__(self):
        return (len(self.data) // BATCH_SIZE) * BATCH_SIZE

# precision/optimizer/scaler

dtype = torch.float16

ctx = torch.amp.autocast(device_type='cuda', dtype=dtype)

optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

scaler = torch.cuda.amp.GradScaler()

Done!
Layer (type:depth-idx)                        Param #
├─TransformerWrapper: 1-1                     --
|    └─TokenEmbedding: 2-1                    --
|    |    └─Embedding: 3-1                    428,032
|    └─AbsolutePositionalEmbedding: 2-2       --
|    |    └─Embedding: 3-2                    4,194,304
|    └─Identity: 2-3                          --
|    └─Dropout: 2-4                           --
|    └─Identity: 2-5                          --
|    └─Encoder: 2-6                           --
|    |    └─ModuleList: 3-3                   25,202,688
|    |    └─LayerNorm: 3-4                    1,024
|    └─Linear: 2-7                            428,868
Total params: 30,254,916
Trainable params: 30,254,916
Non-trainable params: 0


# (TRAIN MODEL)

In [15]:
# @title Train the model

train_losses = []
val_losses = []

train_accs = []
val_accs = []

nsteps = 0

for ep in range(NUM_EPOCHS):

  print('=' * 70)
  print('Epoch #', ep)
  print('=' * 70)

  random.shuffle(train_data)

  train_dataset = MusicDataset(train_data, SEQ_LEN)
  val_dataset   = MusicDataset(train_data, SEQ_LEN)
  train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
  val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

  NUM_BATCHES = len(train_data) // BATCH_SIZE // GRADIENT_ACCUMULATE_EVERY

  for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='Training'):
      model.train()

      for __ in range(GRADIENT_ACCUMULATE_EVERY):
          with ctx:
              loss, acc = model(next(train_loader))
          loss = loss / GRADIENT_ACCUMULATE_EVERY
          scaler.scale(loss).backward(torch.ones(loss.shape).cuda())

      if i % PRINT_STATS_EVERY == 0:
          print(f'Training loss: {loss.mean().item() * GRADIENT_ACCUMULATE_EVERY}')
          print(f'Training acc: {acc.mean().item()}')

      train_losses.append(loss.mean().item() * GRADIENT_ACCUMULATE_EVERY)
      train_accs.append(acc.mean().item())

      scaler.unscale_(optim)
      torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
      scaler.step(optim)
      scaler.update()
      optim.zero_grad(set_to_none=True)

      nsteps += 1

      if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
          with ctx:
            val_loss, val_acc = model(next(val_loader))

            print(f'Validation loss: {val_loss.mean().item()}')
            print(f'Validation acc: {val_acc.mean().item()}')

            val_losses.append(val_loss.mean().item())
            val_accs.append(val_acc.mean().item())

            print('Plotting training loss graph...')

            tr_loss_list = train_losses
            plt.plot([i for i in range(len(tr_loss_list))] ,tr_loss_list, 'b')
            plt.show()
            plt.close()
            print('Done!')

            print('Plotting training acc graph...')

            tr_loss_list = train_accs
            plt.plot([i for i in range(len(tr_loss_list))] ,tr_loss_list, 'b')
            plt.show()
            plt.close()
            print('Done!')

            print('Plotting validation loss graph...')
            tr_loss_list = val_losses
            plt.plot([i for i in range(len(tr_loss_list))] ,tr_loss_list, 'b')
            plt.show()
            plt.close()
            print('Done!')

            print('Plotting validation acc graph...')
            tr_loss_list = val_accs
            plt.plot([i for i in range(len(tr_loss_list))] ,tr_loss_list, 'b')
            plt.show()
            plt.close()
            print('Done!')

      if i % GENERATE_EVERY == 0:
        model.eval()

        inp = random.choice(val_dataset)[:512]

        print(inp)

        with ctx:

            sample = model.generate(inp[None, ...], GENERATE_LENGTH)

        print(sample)

      if i % SAVE_EVERY == 0:

          print('Saving model progress. Please wait...')
          print('model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss_' + str(round(float(train_accs[-1]), 4)) + '_acc.pth')

          fname = '/content/model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss_' + str(round(float(train_accs[-1]), 4)) + '_acc.pth'

          torch.save(model.state_dict(), fname)

          data = [train_losses, train_accs, val_losses, val_accs]

          TMIDIX.Tegridy_Any_Pickle_File_Writer(data, '/content/losses_accs')

          print('Done!')

#======================================================================================================

print('Saving model progress. Please wait...')
print('model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss_' + str(round(float(train_accs[-1]), 4)) + '_acc.pth')

fname = '/content/model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss_' + str(round(float(train_accs[-1]), 4)) + '_acc.pth'

torch.save(model.state_dict(), fname)

print('Done!')

data = [train_losses, train_accs, val_losses, val_accs]

TMIDIX.Tegridy_Any_Pickle_File_Writer(data, '/content/losses_accuracies')

# Save training loss graph

plt.plot([i for i in range(len(train_losses))] ,train_losses, 'b')
plt.savefig('/content/training_loss_graph.png')
plt.close()
print('Done!')

# Save training acc graph

plt.plot([i for i in range(len(train_accs))] ,train_accs, 'b')
plt.savefig('/content/training_acc_graph.png')
plt.close()
print('Done!')

# Save validation loss graph

plt.plot([i for i in range(len(val_losses))] ,val_losses, 'b')
plt.savefig('/content/validation_loss_graph.png')
plt.close()
print('Done!')

# Save validation acc graph

plt.plot([i for i in range(len(val_accs))] ,val_accs, 'b')
plt.savefig('/content/validation_acc_graph.png')
plt.close()
print('Done!')

Epoch # 0


Training:   0%|          | 0/1171 [00:01<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 4.00 GiB. GPU 0 has a total capacity of 39.56 GiB of which 2.41 GiB is free. Process 32200 has 37.14 GiB memory in use. Of the allocated memory 33.30 GiB is allocated by PyTorch, and 3.35 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

# (SAVE MODEL)

In [None]:
# @title Manual save
print('Saving model progress. Please wait...')
print('model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss_' + str(round(float(train_accs[-1]), 4)) + '_acc.pth')

fname = '/content/model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss_' + str(round(float(train_accs[-1]), 4)) + '_acc.pth'

torch.save(model.state_dict(), fname)

print('Done!')

data = [train_losses, train_accs, val_losses, val_accs]

TMIDIX.Tegridy_Any_Pickle_File_Writer(data, '/content/losses_accuracies')

# Save training loss graph

plt.plot([i for i in range(len(train_losses))] ,train_losses, 'b')
plt.savefig('/content/training_loss_graph.png')
plt.close()
print('Done!')

# Save training acc graph

plt.plot([i for i in range(len(train_accs))] ,train_accs, 'b')
plt.savefig('/content/training_acc_graph.png')
plt.close()
print('Done!')

# Save validation loss graph

plt.plot([i for i in range(len(val_losses))] ,val_losses, 'b')
plt.savefig('/content/validation_loss_graph.png')
plt.close()
print('Done!')

# Save validation acc graph

plt.plot([i for i in range(len(val_accs))] ,val_accs, 'b')
plt.savefig('/content/validation_acc_graph.png')
plt.close()
print('Done!')

# (EVAL MODEL)

In [None]:
# @title Eval model
dtype = 'float16'
device_type = 'cuda'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)

model.eval()

x = torch.tensor(random.choice(train_data)[:2048], dtype=torch.long, device='cuda')[None, ...]
#x = torch.tensor([[834]] * 1, dtype=torch.long, device='cuda')


# run generation

with ctx:
    out = model.generate(x,
                        1024,
                        temperature=0.9,
                        return_prime=False,
                        verbose=True)

y = out.tolist()

print('---------------')
print(y[0])

In [None]:
#@title Convert output INTs to MIDI

train_data1 = y[0] # batch number goes here

print('Sample INTs', train_data1[:15])

out = train_data1

patches = [0] * 16
patches[3] = 40

if len(out) != 0:

    song = out
    song_f = []

    time = 0
    dur = 0
    vel = 90
    pitch = 0
    channel = 0

    for ss in song:

        if 0 <= ss < 128:

            time += ss

        if 128 <= ss < 256:

            dur = ss-128

        if 256 <= ss < 512:

            pitch = (ss-256) % 128

            channel = (ss-256) // 128

            if channel == 1:
              channel = 3
              vel = 110
            else:
              channel = 0
              vel = 80

            song_f.append(['note', time, dur, channel, pitch, vel ])

detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,
                                                          output_signature = 'Tiny Music Transformer',
                                                          output_file_name = '/content/Tiny-Music-Transformer-Composition',
                                                          track_name='Project Los Angeles',
                                                          list_of_MIDI_patches=patches,
                                                          timings_multiplier=32
                                                          )

print('Done!')

# (PLOT TOKENS EMBEDDINGS)

In [None]:
# @title Plot model tokens embeddings
tok_emb = model.net.token_emb.emb.weight.detach().cpu().tolist()

cos_sim = metrics.pairwise_distances(
  tok_emb, metric='cosine'
)
plt.figure(figsize=(7, 7))
plt.imshow(cos_sim, cmap="inferno", interpolation="nearest")
im_ratio = cos_sim.shape[0] / cos_sim.shape[1]
plt.colorbar(fraction=0.046 * im_ratio, pad=0.04)
plt.xlabel("Position")
plt.ylabel("Position")
plt.tight_layout()
plt.plot()
plt.savefig("/content/Tiny-Music-Transformer-Tokens-Embeddings-Plot.png", bbox_inches="tight")

# Congrats! You did it! :)