In [1]:
#Smaller sequences to speed up the training
sequence_limit=None
num_epochs=100
sequence_length=16
batch_size=16

In [2]:
from sklearn.preprocessing import MinMaxScaler
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pretty_midi
import glob
from torch.utils.data import DataLoader, TensorDataset

# Define scalers for velocity and timing features
velocity_scaler = MinMaxScaler(feature_range=(0, 1))
time_scaler = MinMaxScaler(feature_range=(0, 1))
np.set_printoptions(suppress=True)


In [3]:
# Function to extract MIDI features
def extract_midi_features(midi_file):
    midi_data = pretty_midi.PrettyMIDI(midi_file)
    notes = []  # Initialize notes inside the function
    for instrument in midi_data.instruments:
        for note in instrument.notes:
            notes.append([
                note.pitch, 
                note.velocity,
                note.start, 
                note.end-note.start
            ])
    
    return np.array(notes)  # Return all notes

# Load all MIDI files
midi_files = glob.glob("miniscule-dataset/*.midi")
data = []
for f in midi_files:
    notes = extract_midi_features(f)
    for i in range(0, len(notes) - sequence_length-1):
        data.append(notes[i:i+sequence_length])  # Create 16-note sequences
#this line of code ensures that data only contains elements with the shape (sequence_length, 4), filtering out any elements that do not match this shape.
data = [d for d in data if d.shape == (sequence_length, 4)]  # Filter inconsistent samples


In [14]:
if sequence_limit is not None:
    data = data[:sequence_limit]
dataset = torch.tensor(np.array(data), dtype=torch.float32)
print(dataset.shape)

#print(dataset[0])

# Print the tensor without scientific notation
#print(dataset[990].numpy())
print(dataset[990].numpy())

torch.Size([3140, 16, 4])
[[81.         97.          3.9010417   0.02734375]
 [38.         52.          4.013021    0.08723959]
 [45.         64.          4.0638022   0.04947917]
 [50.         89.          4.0989585   0.03515625]
 [62.         95.          4.0950522   0.04036458]
 [74.         87.          4.3020835   0.02734375]
 [72.         40.          4.3398438   0.03385417]
 [76.         90.          4.3841147   0.06640625]
 [77.         81.          4.4830728   0.05598958]
 [78.         30.          4.484375    0.07552084]
 [79.         69.          4.5247397   0.07942709]
 [81.         96.          4.5872397   0.0234375 ]
 [69.         86.          4.783854    0.02994792]
 [45.         44.          4.7890625   0.046875  ]
 [47.         57.          4.8333335   0.10286459]
 [81.         78.          4.955729    0.04036458]]


In [15]:
# Create DataLoader for batch training
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Transformer Model with Positional Encoding
class PositionalEncoding(nn.Module):
    #def __init__(self, d_model, max_len=16):
    def __init__(self, d_model, max_len=sequence_length):

        super(PositionalEncoding, self).__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.pe = pe.unsqueeze(0)

    def forward(self, x):
        return x + self.pe.to(x.device)

class MidiTransformer(nn.Module):
    def __init__(self, input_dim=4, model_dim=128, num_heads=4, num_layers=3, ff_dim=512):
        super(MidiTransformer, self).__init__()
        self.embedding = nn.Linear(input_dim, model_dim)
        self.pos_encoder = PositionalEncoding(model_dim)
        encoder_layers = nn.TransformerEncoderLayer(
            d_model=model_dim, nhead=num_heads, dim_feedforward=ff_dim
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers)
        self.fc = nn.Linear(model_dim, input_dim)  

    def forward(self, x):
        x = self.embedding(x)  
        x = self.pos_encoder(x)  
        x = self.transformer_encoder(x)  
        x = self.fc(x)  
        return x


In [16]:
# Model, loss function, and optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MidiTransformer().to(device)
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop with batches
for epoch in range(num_epochs):
    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)  # Move batch to GPU
        optimizer.zero_grad()
        outputs = model(batch)  
        loss = criterion(outputs, batch)  
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {total_loss / len(train_loader)}")




Epoch 0, Loss: 33.18935974237277
Epoch 10, Loss: 2.2972006507331346
Epoch 20, Loss: 1.8873484503799283


KeyboardInterrupt: 

In [38]:
# Generate a new MIDI sequence from a random input
random_index = np.random.randint(0, len(dataset))
random_sequence = dataset[random_index].unsqueeze(0)
generated_sequence = model(random_sequence).detach().numpy()


In [42]:
# Convert generated sequence to MIDI
def sequence_to_midi(sequence, output_file="generated.mid"):
    midi = pretty_midi.PrettyMIDI()
    instrument = pretty_midi.Instrument(program=0)  

    # Adjust the start times to ensure the sequence starts at 0
    min_start_time = min(note_data[2] for note_data in sequence)
    
    for note_data in sequence:
        pitch, velocity = map(int, note_data[:2])
        start, duration = map(float, note_data[2:])
        start -= min_start_time  # Adjust start time to ensure it starts at 0
        note = pretty_midi.Note(
            velocity=max(0, min(127, velocity)),  
            pitch=max(0, min(127, pitch)),  
            start=max(0, start),
            end=max(0, start+duration)
        )
        instrument.notes.append(note)

    midi.instruments.append(instrument)
    midi.write(output_file)
    for noteInformation in midi.instruments[0].notes:
        print(noteInformation)
    print(f"Generated MIDI saved as {output_file}")
sequence_to_midi(generated_sequence[0], "generated.mid")


Note(start=8.854250, end=8.968792, pitch=72, velocity=79)
Note(start=8.315794, end=8.440536, pitch=73, velocity=69)
Note(start=7.926090, end=8.021977, pitch=73, velocity=73)
Note(start=1.930403, end=2.035871, pitch=75, velocity=70)
Note(start=2.401875, end=2.516029, pitch=76, velocity=63)
Note(start=0.000000, end=0.121437, pitch=76, velocity=68)
Note(start=3.496042, end=3.612642, pitch=71, velocity=80)
Note(start=3.335733, end=3.443284, pitch=75, velocity=74)
Note(start=2.097795, end=2.222771, pitch=75, velocity=76)
Note(start=4.751883, end=4.853800, pitch=77, velocity=67)
Note(start=5.770882, end=5.861656, pitch=69, velocity=82)
Note(start=2.962078, end=3.050597, pitch=69, velocity=80)
Note(start=2.335300, end=2.393865, pitch=70, velocity=88)
Note(start=4.899960, end=5.011820, pitch=74, velocity=75)
Note(start=4.388321, end=4.490726, pitch=77, velocity=69)
Note(start=1.978020, end=2.091052, pitch=76, velocity=74)
Generated MIDI saved as generated.mid
