# Text to Music Transformer

## Dependencies - run once

In [1]:
# !pip install miditok==2.1.8 torchmetrics
# !pip install miditoolkit 
# !pip install datasets
# !pip install transformers
# !pip install accelerate -U
# !pip install tensorrt
# !pip install tensorboardX
# !pip install fairseq

# !pip install utils
# !pip install torch
# !pip install unidecode

# !pip install ipywidgets
# !jupyter labextension install @jupyter-widgets/jupyterlab-manager

# # !pip install torch - takes 15-18 mins
# !pip install git+https://github.com/idiap/fast-transformers

# !pip install torch==2.0.0+cu117 torchvision==0.15.1+cu117 torchaudio==2.0.1 --index-url https://download.pytorch.org/whl/cu117

# Data Prep

In [1]:
#imports
import time
import argparse
import subprocess
from utils import *
from transformers import AutoTokenizer, RobertaConfig, RobertaModel
import torch
import numpy as np
import re
import os
import math
import requests
from tqdm import tqdm
from unidecode import unidecode
from transformers import AutoModel, AutoConfig, PreTrainedModel
from torch.utils.data import random_split
from torch.utils.data import Dataset
from torch.utils.data import TensorDataset, DataLoader, RandomSampler
from miditok import Octuple, TokenizerConfig
from miditoolkit import MidiFile
from fast_transformers.builders import TransformerEncoderBuilder
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DataParallel
from fast_transformers.masking import TriangularCausalMask

# Seed torch's global RNG. NOTE(review): nothing visible here seeds NumPy, which the
# sampling helpers further down draw from -- confirm whether that is intended.
torch.manual_seed(5)

# Wall-clock reference; the "total setup time" print later subtracts this value.
start_setup_time = time.time()

# Choose the module-level compute device and announce the choice.
if not torch.cuda.is_available():
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")
else:
    print('Using Cuda')
    device = torch.device("cuda")

def get_device():
    """Return ``torch.device("cuda")`` when CUDA is available, otherwise ``torch.device("cpu")``."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
def get_filtered_files(directory_path, exclude_list):
    """List the entries of ``directory_path`` that are not in ``exclude_list``, sorted by name.

    ``os.listdir`` returns entries in arbitrary order. Callers in this notebook
    list the text and MIDI directories separately, pair the two listings by
    index and cut train/val/test slices out of them, so the order must be
    deterministic and identical on every call. Sorting guarantees that.

    Args:
        directory_path: directory to list.
        exclude_list: entry names to drop (e.g. ``'.DS_Store'``).

    Returns:
        A new, alphabetically sorted list of entry names.
    """
    return sorted(
        file_name for file_name in os.listdir(directory_path)
        if file_name not in exclude_list
    )

Using Cuda


In [2]:
# ------ New Constants ------
batch_size = 4  # samples per training/validation DataLoader batch
ngpus = 4       # number of GPUs the model is spread across

# Special-token ids, one id per vocabulary (six fields per music token) -- according to the vocab
music_SOS_token = [1] * 6    # flat list: a single start-of-sequence token
music_EOS_token = [[2] * 6]  # nested list: a length-1 sequence holding the end-of-sequence token

music_EOS_tensor = torch.tensor(music_EOS_token)


In [3]:
# Confirming size of dataset: the MIDI and text directories are paired by index,
# so the two printed counts are expected to match.
for _dataset_dir in ('dataset/midi_partial', 'dataset/text_partial'):
    print(len(get_filtered_files(_dataset_dir, ['.ipynb_checkpoints', '.DS_Store'])))

5000
5000


In [4]:
class _TextMidiSplitDataset(Dataset):
    """Shared implementation of the train/val/test datasets.

    Lists the text and MIDI directories, pairs the two listings by index and
    keeps one contiguous slice of the pairs. Subclasses only decide which slice
    by implementing ``_split_bounds``.

    Each item is ``(text_ids, midi_tokens)``:
      * ``text_ids``: 1-D tensor of RoBERTa token ids for the prompt text;
      * ``midi_tokens``: LongTensor of the tokenized MIDI file with the
        ``music_EOS_token`` row appended.
    """

    TEXT_DIR = 'dataset/text_partial'
    MIDI_DIR = 'dataset/midi_partial'
    EXCLUDED_FILES = ['.ipynb_checkpoints', '.DS_Store']

    def __init__(self):
        # Sort both listings so that (a) index i of one directory lines up with index i
        # of the other and (b) the three split classes, which each list the directories
        # on their own, slice the same ordering and therefore never overlap.
        # Assumes a text file and its MIDI file share the same file stem -- TODO confirm.
        text_files = sorted(get_filtered_files(self.TEXT_DIR, self.EXCLUDED_FILES))
        midi_files = sorted(get_filtered_files(self.MIDI_DIR, self.EXCLUDED_FILES))
        if len(text_files) != len(midi_files):
            raise ValueError(
                f"text/MIDI file counts differ ({len(text_files)} vs {len(midi_files)}); "
                "the two directories are paired by index and must match"
            )

        self.text_tokenizer = AutoTokenizer.from_pretrained("roberta-base")
        # NOTE(review): every split builds its own Octuple tokenizer; the sizes printed
        # further down suggest their vocabularies can differ between instances.
        self.midi_tokenizer = Octuple(TokenizerConfig(use_programs=True))

        start, stop = self._split_bounds(len(midi_files))
        self.text_files = text_files[start:stop]
        self.midi_files = midi_files[start:stop]

    def _split_bounds(self, total):
        """Return ``(start, stop)`` indices of this split within ``total`` paired files."""
        raise NotImplementedError

    def __len__(self):
        return len(self.midi_files)

    def __getitem__(self, idx):
        with open(os.path.join(self.TEXT_DIR, self.text_files[idx]), 'r') as file:
            text = file.read()

        # Collapse the multi-line description into one comma-separated prompt.
        text = text.replace('\n', ', ')
        # truncation=True caps the prompt at the tokenizer's model maximum so an
        # over-long description cannot overflow RoBERTa's position embeddings.
        encoded = self.text_tokenizer(text, return_tensors="pt", truncation=True)
        text_return = torch.squeeze(encoded['input_ids'], 0)

        midi_file = MidiFile(os.path.join(self.MIDI_DIR, self.midi_files[idx]))
        midi = torch.LongTensor(self.midi_tokenizer(midi_file))
        EOS_tensor = torch.LongTensor(music_EOS_token)
        midi_return = torch.cat([midi, EOS_tensor], dim=0)

        return text_return, midi_return


class Train_Dataset(_TextMidiSplitDataset):
    """Training split: the first 80% of the paired files."""

    def _split_bounds(self, total):
        self.train_size = int(0.8 * total)
        return 0, self.train_size


class Val_Dataset(_TextMidiSplitDataset):
    """Validation split: the 10% of the paired files that follow the training split."""

    def _split_bounds(self, total):
        self.val_size = int(0.1 * total)
        train_size = int(0.8 * total)
        return train_size, train_size + self.val_size


class Test_Dataset(_TextMidiSplitDataset):
    """Test split: the last 10% of the paired files."""

    def _split_bounds(self, total):
        self.test_size = int(0.1 * total)
        # Explicit start index: a ``[-0:]`` slice would return the whole list
        # instead of an empty split when test_size is 0.
        return total - self.test_size, total



In [None]:
def collate_fn(data):
    """Pad a list of ``(text_ids, midi_tokens)`` samples into two batch-first tensors.

    Args:
        data: sequence of ``(text, music)`` pairs; within the batch the tensors
            may differ in length along their first dimension.

    Returns:
        ``(text, music)`` where each is the samples stacked along a new leading
        batch dimension and right-padded to the longest sample in the batch.
    """
    # RoBERTa's <pad> id is 1; id 0 is <s>, so zero-padding would append
    # spurious start-of-sentence tokens to every shorter prompt.
    text_pad_id = 1
    # Music padding stays 0, as in the printed batch above (all-zero trailing rows).
    music_pad_id = 0

    text, music = zip(*data)
    text = nn.utils.rnn.pad_sequence(text, batch_first=True, padding_value=text_pad_id)
    music = nn.utils.rnn.pad_sequence(music, batch_first=True, padding_value=music_pad_id)
    return text, music
        
# Instantiate the three splits; each constructor lists the dataset directories
# and loads its own tokenizers.
train_data, val_data, test_data = Train_Dataset(), Val_Dataset(), Test_Dataset()

# Make sure tensor printing uses the default (abbreviated) profile.
torch.set_printoptions(profile="default")


In [6]:
#normal version
# Options common to all three loaders. Shuffling is also on for validation: the
# training loop draws one fresh batch via next(iter(val_dataloader)) each time.
_loader_options = dict(shuffle=True, drop_last=True, collate_fn=collate_fn)

train_dataloader = DataLoader(train_data, batch_size=batch_size, **_loader_options)
val_dataloader = DataLoader(val_data, batch_size=batch_size, **_loader_options)
test_dataloader = DataLoader(test_data, batch_size=1, **_loader_options)

# Sanity check: fetch one training batch and show the first music sequence.
text, midi = next(iter(train_dataloader))
print(midi[0])

tensor([[40, 23, 11,  4,  4, 57],
        [47, 23, 11,  4,  4, 57],
        [52, 23, 11,  4,  4, 57],
        ...,
        [ 0,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0]])


In [8]:
# TRAINING
_excluded_entries = ['.ipynb_checkpoints', '.DS_Store']
text_files = get_filtered_files('dataset/text_partial', _excluded_entries)
midi_files = get_filtered_files('dataset/midi_partial', _excluded_entries)
folder_path = "dataset/midi_partial/"

# Show each split's MIDI tokenizer summary (per-field vocabulary sizes).
for _split in (train_data, val_data, test_data):
    print(_split.midi_tokenizer)

[92, 36, 68, 36, 84, 133] tokens with ('T', 'C') io format(one token stream, multi-voc), without BPE
[92, 36, 68, 36, 64, 133] tokens with ('T', 'C') io format(one token stream, multi-voc), without BPE
[92, 36, 68, 36, 64, 133] tokens with ('T', 'C') io format(one token stream, multi-voc), without BPE


In [10]:
# Run one reference MIDI file through every split's tokenizer, then print the
# tokenizer summaries again. NOTE(review): judging by the sizes printed before and
# after, this appears to be what brings the three vocabularies to the same size
# (the model below sizes its bar embedding from train_data.midi_tokenizer) -- confirm.
_reference_midi = "dataset/midi_partial/mmd_99c1fc2ced2ebca9cf8863ad50ac76d7.mid" #assume this format
for _split in (train_data, val_data, test_data):
    mid = _split.midi_tokenizer(_reference_midi)

for _split in (train_data, val_data, test_data):
    print(_split.midi_tokenizer)

[92, 36, 68, 36, 325, 133] tokens with ('T', 'C') io format(one token stream, multi-voc), without BPE
[92, 36, 68, 36, 325, 133] tokens with ('T', 'C') io format(one token stream, multi-voc), without BPE
[92, 36, 68, 36, 325, 133] tokens with ('T', 'C') io format(one token stream, multi-voc), without BPE


In [11]:
def softmax_with_temperature(logits, temperature):
    """Return the softmax of ``logits / temperature`` as a NumPy array.

    The maximum scaled logit is subtracted before exponentiating. This leaves
    the result mathematically unchanged but keeps ``np.exp`` from overflowing
    to ``inf`` (and the division from producing NaN) when logits are large.

    Args:
        logits: NumPy array of unnormalised scores.
        temperature: positive divisor; values above 1 flatten the distribution,
            values below 1 sharpen it.

    Returns:
        Array of the same shape as ``logits`` whose entries sum to 1.
    """
    scaled = logits / temperature
    exponentials = np.exp(scaled - np.max(scaled))
    return exponentials / np.sum(exponentials)


def weighted_sampling(probs):
    """Draw one index at random, with index ``i`` chosen proportionally to ``probs[i]``.

    The weights are normalised on a copy, so the caller's array is left
    untouched (the earlier in-place ``/=`` silently rescaled it).

    Args:
        probs: 1-D array-like of non-negative weights; need not sum to 1.

    Returns:
        The sampled index (NumPy integer).
    """
    weights = np.asarray(probs, dtype=np.float64)
    weights = weights / weights.sum()
    # Candidates are presented to np.random.choice in descending-probability order.
    order = np.argsort(weights)[::-1]
    return np.random.choice(order, size=1, p=weights[order])[0]


# -- nucleus -- #
def nucleus(probs, p):
    """Nucleus (top-p) sampling: draw an index from the most probable candidates only.

    Indices are ranked by descending probability and kept up to and including
    the first one at which the cumulative probability exceeds ``p``; if it never
    does, all indices are kept. One index is then drawn from the kept set using
    its renormalised probabilities.

    The weights are normalised on a copy, so the caller's array is left
    untouched (the earlier in-place ``/=`` silently rescaled it).

    Args:
        probs: 1-D array-like of non-negative weights.
        p: cumulative-probability threshold.

    Returns:
        The sampled index (NumPy integer).
    """
    weights = np.asarray(probs, dtype=np.float64)
    # Small epsilon in the denominator, as before, guards against a zero sum.
    weights = weights / (weights.sum() + 1e-5)

    order = np.argsort(weights)[::-1]
    cumulative = np.cumsum(weights[order])
    over_threshold = np.flatnonzero(cumulative > p)
    keep = over_threshold[0] + 1 if over_threshold.size else order.size

    candidates = order[:keep]
    candidate_probs = weights[candidates]
    candidate_probs = candidate_probs / candidate_probs.sum()
    return np.random.choice(candidates, size=1, p=candidate_probs)[0]


def sampling(logit, p=None, t=1.0):
    """Sample one class index from a tensor of logits.

    Args:
        logit: torch tensor of logits; it is squeezed, moved to the CPU and
            converted to NumPy before sampling.
        p: nucleus threshold. When ``None`` the draw is made from the full
            distribution; otherwise nucleus (top-p) sampling is used.
        t: softmax temperature.

    Returns:
        The sampled index.
    """
    probs = softmax_with_temperature(logits=logit.squeeze().cpu().numpy(), temperature=t)

    if p is None:
        return weighted_sampling(probs)
    return nucleus(probs, p=p)


# Seconds elapsed since start_setup_time was recorded in the first code cell.
setup_time = time.time() - start_setup_time
print(f"total setup time: {setup_time}")


total setup time: 2.8696582317352295


# Main section

In [12]:
# Size reported by the training split's MIDI tokenizer (its len()).
music_size = len(train_data.midi_tokenizer) 

#FOR EVAL
# Number of music tokens to generate. NOTE: the generation cell further down
# reassigns this name before using it.
max_generation_length = 1000
# File the training loop writes the best-validation state dict to.
saved_checkpoint_path = 'autoregressive_transpose_text_to_midi_model.pt'


In [13]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence, then apply dropout.

    Args:
        d_model: embedding width of the sequences passed to ``forward``.
        dropout: dropout probability applied after the encodings are added.
        max_len: longest sequence length the precomputed table covers.
    """

    def __init__(self, d_model = 768, dropout: float = 0.1, max_len: int = 5000): #changed from 2000-4000 - 30k
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        # Table layout is (max_len, 1, d_model): sine on even channels, cosine on odd.
        # It is registered as a buffer, so it is part of the module's state dict.
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]``

        Returns:
            Tensor of the same shape with position encodings added and dropout applied.

        Raises:
            ValueError: if ``seq_len`` exceeds the ``max_len`` the table was built for.
        """
        seq_len, max_len = x.size(1), self.pe.size(0)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds the positional-encoding table "
                f"(max_len={max_len}); build PositionalEncoding with a larger max_len"
            )

        # The table is sequence-first, so go to [seq_len, batch, dim] for the
        # broadcast add and return to batch-first afterwards.
        x = x.permute(1, 0, 2)
        x = x + self.pe[:x.size(0)]
        x = self.dropout(x)
        x = x.permute(1, 0, 2)
        return x


class MusicTransformer(nn.Module):
    """Text-conditioned autoregressive model over six-field music tokens.

    The prompt is encoded with RoBERTa (pooled output) and placed at the front
    of the sequence, followed by a start token and the music tokens. Each music
    token is six ids (pitch, velocity, duration, position, bar, program); every
    field is embedded to 128 dimensions and the six embeddings are concatenated
    to 768. A causal linear-attention transformer processes the sequence and six
    linear heads produce per-field logits.

    Args:
        is_training: accepted for call-site compatibility; it is not read.
    """

    def __init__(self, is_training=True):
        super(MusicTransformer, self).__init__()
        # NOTE(review): the load-time warning printed below this cell says the pooler
        # weights are newly initialised, and forward() uses pooler_output -- confirm
        # that a randomly initialised pooler is intended.
        self.roberta = RobertaModel.from_pretrained("roberta-base").to(device)
        self.transformer = TransformerEncoderBuilder.from_kwargs(
            n_layers=6, #normally 10...
            n_heads=6,
            query_dimensions=768,
            value_dimensions=128,
            feed_forward_dimensions=768,
            activation='gelu',
            dropout=0.1,
            attention_type="causal-linear",
        ).get()

        # The attribute name (typo included) is part of the saved state-dict keys,
        # so it is kept as is.
        self.postional_encoding = PositionalEncoding()
        self.embedding_pitch = nn.Embedding(num_embeddings=92, embedding_dim=128)
        self.embedding_velocity = nn.Embedding(num_embeddings=36, embedding_dim=128)
        self.embedding_duration = nn.Embedding(num_embeddings=68, embedding_dim=128)
        self.embedding_position = nn.Embedding(num_embeddings=36, embedding_dim=128)
        # The bar vocabulary size is read from the training tokenizer when the model is built.
        self.embedding_bar = nn.Embedding(num_embeddings=len(train_data.midi_tokenizer.vocab[4]), embedding_dim=128)
        self.embedding_program = nn.Embedding(num_embeddings=133, embedding_dim=128)

        self.softmax = nn.Softmax(dim=1)  # not used by forward(); kept for compatibility
        self.dropout = nn.Dropout(p=0.4) # produces different outputs - avoiding different outputs

        self.linear = nn.Linear(768, 256)

        self.pitch_output = nn.Linear(256, 92)
        self.velocity_output = nn.Linear(256, 36)
        self.duration_output = nn.Linear(256, 68)
        self.position_output = nn.Linear(256, 36)
        self.bar_output = nn.Linear(256, len(train_data.midi_tokenizer.vocab[4])) # impt cos its variable! - fix this
        self.program_output = nn.Linear(256, 133)

    def forward(self, text, midi=None, music=None):
        """Compute per-field logits for every music position.

        Args:
            text: prompt token ids, batch-first.
            midi: teacher-forcing targets of shape ``[batch, seq, 6]`` (training).
                The last step is dropped from the input, so output step ``i``
                predicts ``midi[:, i]``.
            music: tokens generated so far, ``[batch, seq, 6]`` (generation).
                Ignored when ``midi`` is given; when both are ``None`` only the
                start token is fed.

        Returns:
            Tuple of six logit tensors ``(pitch, velocity, duration, position,
            bar, program)``, each ``[batch, steps, field_vocab_size]``.
        """
        # Take batch size and device from the input itself rather than from the
        # batch_size / ngpus globals, so any batch (including a short final one)
        # and any replica device works.
        batch, dev = text.size(0), text.device
        start_token = torch.full((batch, 1, 6), fill_value=1, dtype=torch.long, device=dev)

        # Mask out <pad> positions so padded prompts do not leak into the pooled embedding.
        pad_id = self.roberta.config.pad_token_id
        attention_mask = (text != pad_id) if pad_id is not None else None
        text_embedding = self.roberta(text, attention_mask=attention_mask).pooler_output

        pieces = [torch.unsqueeze(text_embedding, 1), self.embed(start_token)]
        if midi is not None:
            pieces.append(self.embed(midi[:, :-1, :]))  #batch,seq, 6 --> batch, seq, 768
        elif music is not None:
            pieces.append(self.embed(music))
        input_tensor = torch.cat(pieces, 1)

        pos_embedded_input = self.postional_encoding(input_tensor)
        attn_mask = TriangularCausalMask(pos_embedded_input.size(1), device=dev)
        # Drop position 0 (the text embedding): the remaining outputs line up with
        # the start token and the music tokens that follow it.
        output = self.transformer(pos_embedded_input, attn_mask)[:, 1:, :]

        output = self.linear(output)

        # Dropout is applied to the logits themselves; it is inactive in eval mode.
        pitch_output = self.dropout(self.pitch_output(output))
        velocity_output = self.dropout(self.velocity_output(output))
        duration_output = self.dropout(self.duration_output(output))
        position_output = self.dropout(self.position_output(output))
        bar_output = self.dropout(self.bar_output(output))
        program_output = self.dropout(self.program_output(output))

        return pitch_output, velocity_output, duration_output, position_output, bar_output, program_output

    def embed(self, octuple):
        """Embed ``[..., 6]`` token ids field by field and concatenate to ``[..., 768]``."""
        field_embeddings = (
            self.embedding_pitch,
            self.embedding_velocity,
            self.embedding_duration,
            self.embedding_position,
            self.embedding_bar,
            self.embedding_program,
        )
        return torch.cat(
            [embedding(octuple[..., field]) for field, embedding in enumerate(field_embeddings)],
            dim=-1,
        )


music_model = MusicTransformer() #this is the training_version

music_model = music_model.to(device)
# Freeze the text encoder: only the music-side layers are trained.
for frozen_param in music_model.roberta.parameters():
    frozen_param.requires_grad = False

# Replicate across the first `ngpus` GPUs (ids derived from the constant instead of hard-coded).
music_model = DataParallel(music_model, device_ids=list(range(ngpus)))


# Cross-entropy over each token field. Index 0 is the value collate_fn pads the
# music batch with, so ignore it: otherwise the padded tail of every shorter
# sequence is counted as a target and dominates the loss.
criterion = torch.nn.CrossEntropyLoss(ignore_index=0)
optimizer = torch.optim.Adam(music_model.parameters(), lr=0.0001)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
import torch.nn.functional as F
import torch.autograd as autograd
import time
from torchmetrics.classification import MulticlassAccuracy

# accuracy = MulticlassAccuracy(num_classes=max_classes).to(device)
accuracy_pitch = MulticlassAccuracy(num_classes=92).to(device)
accuracy_velocity = MulticlassAccuracy(num_classes=36).to(device)
accuracy_duration = MulticlassAccuracy(num_classes=68).to(device)
accuracy_position = MulticlassAccuracy(num_classes=36).to(device)
accuracy_bar_train = MulticlassAccuracy(num_classes=len(train_data.midi_tokenizer.vocab[4])).to(device) #change for 1 for train 1 for val!
accuracy_bar_val = MulticlassAccuracy(num_classes=len(val_data.midi_tokenizer.vocab[4])).to(device) #change for 1 for train 1 for val!
accuracy_program = MulticlassAccuracy(num_classes=133).to(device)

# Metrics in the same order as the model's output tuple. The bar metric is the one
# sized from the training tokenizer because the model's bar head is sized from it too.
validation_metrics = [accuracy_pitch, accuracy_velocity, accuracy_duration,
                      accuracy_position, accuracy_bar_train, accuracy_program]


def _class_first(outputs):
    """Permute every field's logits from [batch, seq, classes] to [batch, classes, seq]."""
    return [logits.permute(0, 2, 1) for logits in outputs]


def _mean_field_loss(field_logits, target):
    """Average the six per-field cross-entropy losses against ``target[..., field]``."""
    losses = [criterion(logits, target[..., field]) for field, logits in enumerate(field_logits)]
    return sum(losses) / 6


music_model.train()
num_epochs = 10 # use 20
total_batches = len(train_dataloader)
best_val_loss = 9999

softmax = nn.Softmax(dim=1) #softmax after permute for classification. (not used in this cell)
# dropout = nn.Dropout(p=0.2)

for epoch in range(num_epochs):
    start_time = time.time()  # Record the start time for the epoch
    print("\nEpoch: " + str(epoch+1))
    for i, data in enumerate(train_dataloader): #new
        start_batch_time = time.time()  # Record the start time for the batch
        batch_number = i + 1
        optimizer.zero_grad()

        text, midi = data # new
        text, midi = text.to(device), midi.to(device)
        output = music_model(text, midi)

        total_average_loss = _mean_field_loss(_class_first(output), midi)
        total_average_loss.backward()
        optimizer.step()

        batch_time = time.time() - start_batch_time

        #upkeep
        print(f'Epoch [{epoch + 1}/{num_epochs}], Batch [{batch_number}/{total_batches}] took {batch_time:.2f} seconds, Average Loss: {total_average_loss:.4f}, Length: {midi.shape}')

        # #validation
        # Every 8th batch, score one freshly drawn (shuffled) validation batch.
        if i % 8 == 0:

            val_start_time = time.time()
            val_text, val_midi = next(iter(val_dataloader))

            val_text, val_midi = val_text.to(device), val_midi.to(device)

            # Evaluate with dropout off and without building an autograd graph; the
            # model is put back into train mode even if scoring raises.
            music_model.eval()
            try:
                with torch.no_grad():
                    val_logits = _class_first(music_model(val_text, val_midi))
                    val_total_average_loss = _mean_field_loss(val_logits, val_midi).item()
                    (pitch_accuracy, velocity_accuracy, duration_accuracy,
                     position_accuracy, bar_accuracy, program_accuracy) = [
                        metric(logits, val_midi[..., field])
                        for field, (metric, logits) in enumerate(zip(validation_metrics, val_logits))
                    ]
            finally:
                music_model.train()

            print(f'* Accuracy | Pitch: [{pitch_accuracy.item():.4f}], Velocity: [{velocity_accuracy.item():.4f}], Duration: [{duration_accuracy.item():.4f}], Position: [{position_accuracy.item():.4f}], Bar: [{bar_accuracy.item():.4f}], Program: [{program_accuracy.item():.4f}]')

            # best_val_loss is a plain float, so no tensor (or graph) is kept alive between steps.
            if val_total_average_loss < best_val_loss:
                best_val_loss = val_total_average_loss
                torch.save(music_model.state_dict(), saved_checkpoint_path)
            val_time = time.time() - val_start_time
            print(f'* Epoch [{epoch + 1}/{num_epochs}], Validation took {val_time:.2f} seconds, Validation Loss: {val_total_average_loss:.4f}, Best Validation Loss: {best_val_loss:.4f}')


    epoch_time = time.time() - start_time
    print(f"Epoch {epoch+1} took {epoch_time:.2f} seconds")

    # Multiplicative learning-rate decay after every epoch.
    for g in optimizer.param_groups:
        g['lr'] *= 0.8 #was 0.7 then 0.6 most of the time
# print(torch.topk(output[0], 1).indices)


# Saving and evaluating - check for checkpoint error

In [20]:

from collections import OrderedDict

def load_saved_model(saved_checkpoint_path):
    """Build a fresh ``MusicTransformer`` and load a saved state dict into it.

    The training cell saves the state dict of the ``DataParallel`` wrapper, whose
    keys carry a ``"module."`` prefix. The prefix is removed only from keys that
    actually have it, so a checkpoint saved from an unwrapped model loads too.

    Args:
        saved_checkpoint_path: path of the ``torch.save``-d state dict.

    Returns:
        The model, moved to the device returned by ``get_device()``.
    """
    device = get_device()
    saved_transformer_model = MusicTransformer(is_training=False) # for evaluation

    print(saved_checkpoint_path)
    saved_transformer_model = saved_transformer_model.to(device)
    # map_location lets a checkpoint written on GPU be loaded on a CPU-only machine.
    state_dict = torch.load(saved_checkpoint_path, map_location=device)

    prefix = "module."
    new_state_dict = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )

    # strict=False is kept, but a mismatch is reported instead of passing silently:
    # otherwise a checkpoint with the wrong key names would leave the model at its
    # random initialisation without any sign of it.
    load_result = saved_transformer_model.load_state_dict(new_state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            f"Warning: checkpoint/model key mismatch - "
            f"missing: {list(load_result.missing_keys)}, "
            f"unexpected: {list(load_result.unexpected_keys)}"
        )

    return saved_transformer_model


def prepare_text_input(input_string):
    """Tokenize ``input_string`` with the roberta-base tokenizer.

    Returns:
        The token ids reshaped to ``(1, -1)`` and moved to the module-level ``device``.
    """
    roberta_tokenizer = AutoTokenizer.from_pretrained("roberta-base")
    input_ids = roberta_tokenizer(input_string, return_tensors="pt")['input_ids']
    # Drop the leading dimension, then restore an explicit batch of one.
    single_batch = torch.squeeze(input_ids, 0).reshape((1, -1))
    return single_batch.to(device)



# the format of the output vs the midi quite different.
# more for batchsize of 1 for now!!
def shape_output_to_tensor(output):
    """Turn the model's six logit tensors into one sampled music token.

    Only the last time step of the last batch row is used. Pitch, velocity,
    duration, position and program are sampled (temperature and optional
    nucleus threshold per field); bar takes the highest-scoring class.

    Args:
        output: tuple ``(pitch, velocity, duration, position, bar, program)`` of
            logit tensors, each ``[batch, seq, classes]``.

    Returns:
        1-D tensor of six ids ``[pitch, velocity, duration, position, bar, program]``.
    """
    # Slice by the tensors' own batch dimension instead of consulting the
    # batch_size / ngpus globals: last row, last step, kept as shape [1, classes].
    pitch, velocity, duration, position, bar, program = [head[-1:, -1, :] for head in output]

    #most updated hyperparams
    cur_pitch = sampling(pitch, t=1.2, p=0.5)
    cur_velocity = sampling(velocity, t=2)
    cur_duration = sampling(duration, t=1.25, p=0.5)
    cur_position = sampling(position, t=1.5, p=0.5)
    cur_program = sampling(program, t=1.25, p=0.45)

    # Bar is chosen greedily. (A sampled bar used to be computed first and then
    # immediately overwritten by this top-1 pick; that dead draw is removed.)
    cur_bar = int(torch.topk(bar, 1).indices.item())

    next_arr = np.array([
        cur_pitch,
        cur_velocity,
        cur_duration,
        cur_position,
        cur_bar,
        cur_program,
    ])

    return torch.tensor(next_arr)

### Evaluation Code - Autoregressive

In [21]:

# Alternative prompts (uncomment one to try it):
# input_string = "slow sad song, with a long flowing melody line"
# input_string = "Classical love song, upbeat tempo, and complex harmony."
input_string = "Dramatic film score, with romantic melodies"

# Checkpoint to evaluate; same file name the training cell saves to.
saved_checkpoint_path_eval = 'autoregressive_transpose_text_to_midi_model.pt'
# Rebuild the model from the checkpoint and tokenize the prompt into a (1, n) id tensor.
saved_transformer_model = load_saved_model(saved_checkpoint_path_eval)
text_input = prepare_text_input(input_string)


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


autoregressive_22feb_transpose_text_to_midi_model.pt


In [22]:
music_model.eval()
saved_transformer_model.eval()

max_generation_length = 800

# Number of identical rows fed to the model per call. This mirrors the per-replica
# batch (batch_size / ngpus), so it works for any ratio rather than only 1 or 2;
# shape_output_to_tensor reads a single row back.
rows_per_call = max(1, int(batch_size / ngpus))

with torch.no_grad():
    model_text = text_input.repeat(rows_per_call, 1)

    # First step: prompt only (the model prepends the start token itself).
    step_output = saved_transformer_model(text=model_text)      #tuple of 6 : rows,1,classes
    eval_output = shape_output_to_tensor(step_output).to(device) #dimension = [6]
    midi_output = eval_output.unsqueeze(0).unsqueeze(0)          # [1, 1, 6]

    # Remaining steps: feed everything generated so far and append one new token.
    # The whole sequence is re-run through the model on every step.
    for i in range(max_generation_length-1):
        step_output = saved_transformer_model(text=model_text, music=midi_output.repeat(rows_per_call, 1, 1))
        eval_output = shape_output_to_tensor(step_output).to(device)

        if i %50 == 0:
            print(midi_output.shape)

        midi_output = torch.cat((midi_output, eval_output.unsqueeze(0).unsqueeze(0)), dim=1)

    print(midi_output.shape)
    print("done")
    



torch.Size([1, 1, 6])
torch.Size([1, 51, 6])
torch.Size([1, 101, 6])
torch.Size([1, 151, 6])
torch.Size([1, 201, 6])
torch.Size([1, 251, 6])
torch.Size([1, 301, 6])
torch.Size([1, 351, 6])
torch.Size([1, 401, 6])
torch.Size([1, 451, 6])
torch.Size([1, 501, 6])
torch.Size([1, 551, 6])
torch.Size([1, 601, 6])
torch.Size([1, 651, 6])
torch.Size([1, 701, 6])
torch.Size([1, 751, 6])
torch.Size([1, 800, 6])
done


In [23]:
# Print the whole generated tensor: switch to the "full" print profile for this
# one print, then go back to the default (abbreviated) profile.
torch.set_printoptions(profile="full")
print(midi_output)
torch.set_printoptions(profile="default")


tensor([[[ 20,  28,  27,   4,   4, 100],
         [  7,  13,  67,   4,   4,  89],
         [ 43,  24,  27,   4,   4,  25],
         [ 55,  26,  27,   4,   4,  73],
         [ 67,  25,  11,   4,   4,  57],
         [ 55,  28,  50,  12,   4,  57],
         [ 59,  16,  27,  20,   4,  73],
         [ 38,  31,  27,  20,   4,  47],
         [ 59,  28,  27,  20,   4,  59],
         [ 35,  23,  27,  20,   4,  65],
         [ 43,  20,  27,  20,   4,  19],
         [ 52,  24,  27,  20,   4,  73],
         [ 55,  16,  27,  28,   4, 114],
         [ 59,  17,  27,  28,   4,  40],
         [ 38,  28,  27,  28,   4,  73],
         [ 57,  20,  27,  28,   4,   5],
         [ 59,  21,  27,  28,   4,  51],
         [ 40,  20,  19,  28,   4,  24],
         [ 57,  21,  19,  28,   4,  20],
         [ 59,  13,  27,  32,   4,  31],
         [ 31,  25,  27,  32,   5,  73],
         [ 40,  23,  27,  32,   5,  53],
         [ 52,  23,  27,  32,   5,  73],
         [ 31,  29,  11,   4,   5,   5],
         [ 43,  

In [24]:
output_Octuple = []

# Build one id -> token-name table per field vocabulary up front, instead of
# scanning the whole vocabulary for every generated id.
id_to_token_tables = []
for vocab in train_data.midi_tokenizer.vocab:
    table = {}
    for key, value in vocab.items():
        table.setdefault(value, key)  # first key wins, as with the earlier linear scan
    id_to_token_tables.append((table, len(vocab)))

for row in midi_output[0]:
    tokens = []
    # .tolist() gives plain Python ints, so lookups do not compare against tensors.
    for idx, (table, vocab_size) in zip(row.tolist(), id_to_token_tables):
        # Clamp to the last valid id - ensures that value is within range
        clamped_idx = min(idx, vocab_size - 1)
        token = table.get(clamped_idx)

        if token is None:
            # The clamped id has no entry in this vocabulary; report the id that was looked up.
            print("Error detected:")
            print(clamped_idx)
        tokens.append(token)
    output_Octuple.append(tokens)

# Preview the first five decoded tokens.
for row in output_Octuple[:5]:
    print(row)


['Pitch_37', 'Velocity_99', 'Duration_3.0.8', 'Position_0', 'Bar_0', 'Program_95']
['Pitch_24', 'Velocity_39', 'Duration_12.0.4', 'Position_0', 'Bar_0', 'Program_84']
['Pitch_60', 'Velocity_83', 'Duration_3.0.8', 'Position_0', 'Bar_0', 'Program_20']
['Pitch_72', 'Velocity_91', 'Duration_3.0.8', 'Position_0', 'Bar_0', 'Program_68']
['Pitch_84', 'Velocity_87', 'Duration_1.0.8', 'Position_0', 'Bar_0', 'Program_52']


In [25]:
# Creating the MIDI file from the decoded tokens: the training split's tokenizer is
# called on the list of token-name rows, and the returned object is written to disk.
gen_midi = train_data.midi_tokenizer(output_Octuple)
gen_midi.dump('transformer_output.mid')
