In [1]:
import torch
import torch.nn as nn
from miditok import MIDILike
from miditok.pytorch_data import DatasetMIDI, DataCollator
from miditok.utils import split_files_for_training
from torch.utils.data import DataLoader
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import time
import numpy as np
@torch.no_grad()
def measure_latency(model, dummy_input, n_warmup=20, n_test=100):
    """Time repeated forward passes of ``model`` on one fixed input.

    The model is switched to eval mode (and left there). It then gets
    ``n_warmup`` untimed calls followed by ``n_test`` individually timed calls.

    Args:
        model: an ``nn.Module`` that can be called on ``dummy_input``.
        dummy_input: the input passed to ``model`` on every call.
        n_warmup: number of untimed warm-up calls.
        n_test: number of timed calls; must be positive.

    Returns:
        ``(mean, std)`` of the per-call wall-clock latency, in seconds.

    Raises:
        ValueError: if ``n_test`` is not positive (the mean would be NaN).
    """
    if n_test <= 0:
        raise ValueError(f"n_test must be positive, got {n_test}")
    model.eval()

    # CUDA kernels launch asynchronously. Without a synchronize, the timer
    # would measure only launch overhead, not the forward pass itself.
    on_cuda = isinstance(dummy_input, torch.Tensor) and dummy_input.is_cuda

    def _sync():
        if on_cuda:
            torch.cuda.synchronize(dummy_input.device)

    # warm-up: lets allocators / kernel selection settle before timing
    for _ in range(n_warmup):
        model(dummy_input)
    _sync()

    latencies = []
    for _ in range(n_test):
        start = time.perf_counter()  # monotonic, high-resolution clock
        model(dummy_input)
        _sync()
        latencies.append(time.perf_counter() - start)

    return np.mean(latencies), np.std(latencies)

In [2]:
class MusicLSTM(nn.Module):
    """Token-level music language model: embedding -> stacked LSTMs -> logits.

    The stack is ``num_layers`` separate single-layer ``nn.LSTM`` modules
    (attribute names unchanged so existing state_dicts still load). Each layer
    keeps its *own* recurrent state. A layer's final ``(h, c)`` must not be
    handed to the next layer as its initial state: that state summarises the
    whole sequence, so upper layers would see future tokens at every position,
    and step-by-step generation would disagree with a full-sequence pass.
    NOTE(review): checkpoints trained with a forward that chained states
    between layers were trained under that leak; re-evaluate / retrain them.

    Args:
        vocab_size: number of token ids (embedding rows and output logits).
        embedding_dim: size of the token embeddings fed to the first LSTM.
        hidden_dim: hidden/cell size of every LSTM layer.
        num_layers: number of stacked LSTM layers.
        fc_dim: accepted for backward compatibility; not used.
        device: stored as ``self.device``; state tensors are created on the
            device of the model's weights.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, fc_dim, device):
        super(MusicLSTM, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.ModuleList([
            nn.LSTM(embedding_dim if i == 0 else hidden_dim, hidden_dim, num_layers=1, batch_first=True)
            for i in range(num_layers)
        ])
        self.fc1 = nn.Linear(hidden_dim, vocab_size)
        self.relu = nn.ReLU()  # not used in forward; kept for attribute compatibility
        self.device = device
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

    def forward(self, x, hidden=None, **kwargs):
        """Compute next-token logits for every position of ``x``.

        Args:
            x: integer token ids, expected ``(batch, seq_len)`` since the LSTMs
                are ``batch_first``.
            hidden: ``None`` (zero state for every layer), the per-layer list
                returned by a previous call, or a single ``(h, c)`` pair
                (e.g. from ``init_hidden``) used as every layer's initial state.
            **kwargs: ignored.

        Returns:
            ``(logits, states)``. ``logits`` is ``fc1`` applied to the last
            layer's outputs, giving ``vocab_size`` values per position.
            ``states`` is a list with one ``(h_n, c_n)`` pair per layer, which
            can be passed back as ``hidden`` to continue the sequence.
        """
        states = self._per_layer_states(hidden)
        x = self.embedding(x)
        new_states = []
        for lstm_layer, state in zip(self.lstm, states):
            # each layer starts from its own state (None -> zeros inside nn.LSTM)
            x, state = lstm_layer(x, state)
            new_states.append(state)
        return self.fc1(x), new_states

    def _per_layer_states(self, hidden):
        """Normalise ``hidden`` into one initial state (or ``None``) per layer."""
        if hidden is None:
            return [None] * len(self.lstm)
        if len(hidden) == 2 and torch.is_tensor(hidden[0]):
            # a single (h, c) pair, e.g. from init_hidden(): reuse for every layer
            return [tuple(hidden)] * len(self.lstm)
        if len(hidden) != len(self.lstm):
            raise ValueError(
                f"expected {len(self.lstm)} per-layer (h, c) states, got {len(hidden)}"
            )
        return list(hidden)

    def init_hidden(self, batch_size):
        """Return a zero ``(h, c)`` pair of shape ``(1, batch_size, hidden_dim)``.

        The tensors follow the device/dtype of the model weights, so they stay
        valid after the model is moved with ``.to(...)``.
        """
        weight = self.embedding.weight
        shape = (1, batch_size, self.hidden_dim)
        return (weight.new_zeros(shape), weight.new_zeros(shape))

In [None]:
import itertools
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Sweep: mean single-call forward latency of MusicLSTM as a function of the
# number of LSTM layers and the embedding size (hidden size held fixed).
vocab_size = 398      # example vocabulary size
hidden_dim = 512      # hidden size, fixed across the sweep
batch_size = 1
seq_length = 1        # one token per call
device = 'cuda' if torch.cuda.is_available() else 'cpu'
fc_dim = 128          # passed through to MusicLSTM's constructor
n_timed_calls = 1000  # timed forward passes per configuration

layer_options = list(range(5, 11))                # 5..10 LSTM layers
embedding_options = [2**i for i in range(5, 10)]  # 32, 64, 128, 256, 512

# latency_matrix[i, j]: mean latency (ms) for layer_options[i] x embedding_options[j]
latency_matrix = np.zeros((len(layer_options), len(embedding_options)))

torch.manual_seed(0)
grid = itertools.product(enumerate(layer_options), enumerate(embedding_options))
for (i, n_layers), (j, emb_dim) in tqdm(grid, total=latency_matrix.size, desc="latency sweep"):
    model = MusicLSTM(vocab_size, emb_dim, hidden_dim, n_layers, fc_dim, device).to(device)
    input_tokens = torch.randint(0, vocab_size, (batch_size, seq_length), device=device)
    mean_s, _ = measure_latency(model, input_tokens, n_test=n_timed_calls)
    latency_matrix[i, j] = mean_s * 1000  # seconds -> milliseconds

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(latency_matrix, annot=True, fmt=".2f", cmap="RdYlGn_r",
            cbar_kws={'label': 'Mean Latency (ms)'},
            xticklabels=embedding_options, yticklabels=layer_options, ax=ax)
ax.set_xlabel("Embedding Dimension")
ax.set_ylabel("Number of Layers")
ax.set_title("Mean Latency (ms) for Different Embedding Sizes and Layer Counts")
plt.show()

In [29]:
# Load the trained checkpoint. The hyper-parameters below must match the ones
# used in training, otherwise load_state_dict fails on missing/mismatched weights.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
CHECKPOINT_PATH = Path("models/bs16_dropout_guided_new_music_lstm.pt")
model = MusicLSTM(
    vocab_size=390,
    embedding_dim=768,
    hidden_dim=768,
    num_layers=3,
    fc_dim=1,        # passed through; MusicLSTM does not use it
    device=device,
).to(device)
model.eval()
# map_location=device loads tensors straight onto the target device.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))



<All keys matched successfully>

In [17]:
import torch
import numpy as np
from pathlib import Path
from tqdm import tqdm

# Whole-context sampling: every step re-feeds the full token sequence from a
# zero state, so the cost of each step grows with the sequence length.
# NOTE(review): `model` was built with vocab_size=390; confirm the MIDILike
# vocabulary matches the one used in training, otherwise sampled ids are
# meaningless to this tokenizer.
tokenizer = MIDILike()
music = torch.Tensor(tokenizer('train_midi.mid')).long().unsqueeze(0)
generated_length = 400
temperature = 1

with torch.no_grad():
    generated_music = music.clone().squeeze(0).to(device)
    print(music.shape)
    for _ in tqdm(range(generated_length)):
        # No recurrent state is carried between steps: the full context is
        # already in the input, so reusing the previous state would count it twice.
        outputs, _ = model(generated_music)

        probabilities = torch.softmax(outputs[:, -1, :] / temperature, dim=-1)
        next_token = torch.multinomial(probabilities, num_samples=1)

        # Append the sampled token; the whole sequence is the next input.
        generated_music = torch.cat((generated_music, next_token), dim=1)

# Convert back to MIDI format (tokenizer works on CPU data)
generated_midi = tokenizer(generated_music.cpu())
generated_midi.dump_midi(Path("whole_context.mid"))



torch.Size([1, 1, 199])


  2%|▎         | 10/400 [00:04<02:42,  2.40it/s]


KeyboardInterrupt: 

In [32]:
import numpy as np
import torch
from pathlib import Path
from processor import encode_midi
from processor import decode_midi

# Stateful sampling: the prompt is fed once, then one new token per step with
# the LSTM state carried forward in `hidden`.
torch.manual_seed(0)  # reproducible samples
pre_music = encode_midi('train_midi.mid')  # prompt token ids (a list, see `+` below)
music = torch.Tensor(pre_music).long()

added_toks = []
music_next = music.clone().unsqueeze(0).to(device)  # (1, prompt_len), same device as model
hidden = None
temperature = 0.4
generated_length = 400

with torch.no_grad():
    for _ in range(generated_length):
        outputs, hidden = model(music_next, hidden)

        probabilities = torch.softmax(outputs[:, -1, :] / temperature, dim=-1)
        next_token = torch.multinomial(probabilities, num_samples=1)  # (1, 1)

        # Only the new token is fed next; the history lives in `hidden`.
        music_next = next_token
        added_toks.append(next_token.item())

prompt_gen = pre_music + added_toks

decode_midi(prompt_gen, 'new_output.mid')

info removed pitch: 64
info removed pitch: 66
info removed pitch: 89
info removed pitch: 63
info removed pitch: 78
info removed pitch: 63
info removed pitch: 78
info removed pitch: 61
info removed pitch: 59
info removed pitch: 76
info removed pitch: 61
info removed pitch: 71
info removed pitch: 88
info removed pitch: 82
info removed pitch: 61
info removed pitch: 61
info removed pitch: 63


<pretty_midi.pretty_midi.PrettyMIDI at 0x1205e5840>

In [33]:
import pretty_midi
import IPython.display
import numpy as np

sr = 48000  # sample rate used for BOTH synthesis and playback
# Load MIDI file into PrettyMIDI object
midi_data = pretty_midi.PrettyMIDI('new_output.mid')

# Synthesize with sine waves at `sr`. synthesize() otherwise defaults to
# fs=44100, and playing that at rate=48000 shifts speed and pitch.
IPython.display.Audio(midi_data.synthesize(fs=sr), rate=sr)