### Step 1: Lakh to NES (make MIDI files compatible to NES-MDB format)

In [1]:
import mido
import numpy as np
import os

def midi_to_tokens(midi_file):
    # Load the MIDI file
    midi = mido.MidiFile(midi_file)

    # Initialize variables
    tokens = []
    current_time = 0
    note_status = np.zeros(128)  # MIDI note range

    # Iterate through MIDI messages
    for msg in midi:
        # Update the current time
        current_time += msg.time

        # Process note on/off events
        if msg.type == 'note_on' and msg.velocity > 0:
            # Note on event
            tokens.append({'time': current_time, 'note': msg.note, 'velocity': msg.velocity})
            note_status[msg.note] = 1
        elif (msg.type == 'note_off' or (msg.type == 'note_on' and msg.velocity == 0)) and note_status[msg.note] == 1:
            # Note off event
            tokens.append({'time': current_time, 'note': msg.note, 'velocity': 0})
            note_status[msg.note] = 0

    return tokens

def process_midi_directory(directory):
    all_tokens = []
    for root, dirs, files in os.walk(directory):
        for file in files:
            if file.endswith('.mid') or file.endswith('.midi'):
                midi_file = os.path.join(root, file)
                try:
                    tokens = midi_to_tokens(midi_file)
                    if tokens:
                        all_tokens.append(tokens)
                except (EOFError, OSError, ValueError) as e:
                    print(e)
    return all_tokens

### Step 2: Implement Transformers Model

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class TransformerModel(nn.Module):
    def __init__(self, ntoken, d_model, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layers = nn.TransformerEncoderLayer(d_model, nhead, nhid, dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, d_model)
        self.d_model = d_model
        self.decoder = nn.Linear(d_model, ntoken)

        self.init_weights()

    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src, src_mask):
        src = self.encoder(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_mask)
        output = self.decoder(output)
        return output

### Step 3: Pre-Training the Transformers Model

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence

In [15]:
MAX_SEQ_LEN = 512  # Set a maximum sequence length

class LakhMIDIDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, index):
        sequence = self.data[index][:MAX_SEQ_LEN]  # Truncate sequence to MAX_SEQ_LEN
        input_sequence = [token['note'] for token in sequence[:-1]]
        target_sequence = [token['note'] for token in sequence[1:]]
        return torch.tensor(input_sequence, dtype=torch.long), torch.tensor(target_sequence, dtype=torch.long)

    def __len__(self):
        return len(self.data)

In [16]:
# Process all MIDI files in the directory
midi_directory = 'midi_dataset_small'
all_tokens = process_midi_directory(midi_directory)

# Flatten the list of token sequences and extract unique tokens
flattened_tokens = [(token['note'], token['velocity']) for sublist in all_tokens for token in sublist]
unique_tokens = set(flattened_tokens)
ntoken = len(unique_tokens)  # Number of unique tokens in dataset

data byte must be in range 0..127
data byte must be in range 0..127
data byte must be in range 0..127


In [17]:
# Set model parameters
d_model = 512  # Dimension of the model
nhead = 8  # Number of heads in the multi-head attention mechanism
nhid = 2048  # Dimension of the feedforward network model in the Transformer layers
nlayers = 6  # Number of Transformer encoder layers
dropout = 0.1  # Dropout rate

# Initialize the Transformer model
transformer_model = TransformerModel(ntoken, d_model, nhead, nhid, nlayers, dropout)

In [18]:
# Load the Lakh MIDI dataset
lakh_dataset = LakhMIDIDataset(all_tokens)

# Define a custom collate function for padding
def pad_collate(batch):
    (xx, yy) = zip(*batch)
    x_lengths = [len(x) for x in xx]
    y_lengths = [len(y) for y in yy]
    x_padded = pad_sequence(xx, batch_first=True, padding_value=0)
    y_padded = pad_sequence(yy, batch_first=True, padding_value=0)
    return x_padded, y_padded, x_lengths, y_lengths

# Create a DataLoader for the Lakh MIDI dataset
lakh_loader = DataLoader(lakh_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=pad_collate)

In [19]:
# Set training hyperparameters
BATCH_SIZE = 64
NUM_EPOCHS = 3
LEARNING_RATE = 0.001

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(transformer_model.parameters(), lr=LEARNING_RATE)

In [22]:
# Move the model to the appropriate device (GPU or CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
transformer_model.to(device)

# Pre-training loop
for epoch in range(NUM_EPOCHS):
    transformer_model.train()
    total_loss = 0

    for batch_idx, (inputs, targets, input_lengths, target_lengths) in enumerate(lakh_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()

        batch_size, seq_len = inputs.size()
        mask = torch.zeros((batch_size, seq_len), device=device).to(torch.bool)
        for i, length in enumerate(input_lengths):
            mask[i, :length] = 1
        mask = mask == 0  # Invert mask: True for padding, False for actual data
        mask = mask.unsqueeze(1) * mask.unsqueeze(2)  # Create a 3D mask: (batch_size, seq_len, seq_len)
        mask = mask.repeat_interleave(nhead, dim=0)  # Repeat mask for each head: (batch_size * num_heads, seq_len, seq_len)

        outputs = transformer_model(inputs, mask)
        loss = criterion(outputs.view(-1, ntoken), targets.view(-1))
        total_loss += loss.item()
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Batch {batch_idx+1}/{len(lakh_loader)}, Loss: {loss.item()}")

    print(f"Epoch {epoch+1} completed, Average Loss: {total_loss / len(lakh_loader)}")

# Save the pre-trained model
torch.save(transformer_model.state_dict(), "pretrained_transformer.pth")

RuntimeError: The shape of the 3D attn_mask is torch.Size([256, 511, 511]), but should be (4088, 32, 32).