In [165]:
sequenceLimit=500

In [166]:
from sklearn.preprocessing import MinMaxScaler
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pretty_midi
import glob
from torch.utils.data import DataLoader, TensorDataset

# Define scalers for velocity and timing features
velocity_scaler = MinMaxScaler(feature_range=(0, 1))
time_scaler = MinMaxScaler(feature_range=(0, 1))
np.set_printoptions(suppress=True)


In [167]:

# Function to extract MIDI features
def extract_midi_features(midi_file):
    midi_data = pretty_midi.PrettyMIDI(midi_file)
    notes = []  # Initialize notes inside the function
    for instrument in midi_data.instruments:
        for note in instrument.notes:
            notes.append([
                note.pitch, 
                note.velocity,
                note.start, 
                note.end-note.start 
            ])
    
    return np.array(notes)  # Return all notes

# Load all MIDI files
midi_files = glob.glob("smaller-dataset/*.midi")
data = []
for f in midi_files:
    notes = extract_midi_features(f)
    for i in range(0, len(notes) - 15):
        data.append(notes[i:i+16])  # Create 16-note sequences
data = [d for d in data if d.shape == (16, 4)]  # Filter inconsistent samples


In [168]:
data = data[:sequenceLimit]
dataset = torch.tensor(np.array(data), dtype=torch.float32)
np.set_printoptions(suppress=True)

print(dataset.shape)
print(dataset[0])

torch.Size([500, 16, 4])
tensor([[7.4000e+01, 9.2000e+01, 1.0234e+00, 6.2500e-02],
        [5.7000e+01, 7.9000e+01, 2.0312e+00, 6.3802e-02],
        [6.2000e+01, 8.6000e+01, 2.5339e+00, 4.2969e-02],
        [8.1000e+01, 9.3000e+01, 1.5456e+00, 1.0534e+00],
        [7.4000e+01, 8.2000e+01, 3.0247e+00, 4.4661e-01],
        [7.8000e+01, 9.7000e+01, 3.0156e+00, 5.1432e-01],
        [7.3000e+01, 7.3000e+01, 3.5182e+00, 4.4271e-02],
        [7.6000e+01, 7.9000e+01, 3.5273e+00, 4.9479e-02],
        [7.1000e+01, 7.8000e+01, 3.6615e+00, 3.9062e-02],
        [7.4000e+01, 7.2000e+01, 3.6667e+00, 4.8177e-02],
        [6.9000e+01, 7.6000e+01, 3.7969e+00, 3.1250e-02],
        [7.3000e+01, 7.9000e+01, 3.7852e+00, 7.1615e-02],
        [6.7000e+01, 7.8000e+01, 3.9284e+00, 3.7760e-02],
        [7.1000e+01, 8.3000e+01, 3.9271e+00, 5.7292e-02],
        [6.6000e+01, 7.6000e+01, 4.0638e+00, 3.5156e-02],
        [6.9000e+01, 7.7000e+01, 4.0625e+00, 5.0781e-02]])


In [169]:
# Create DataLoader for batch training
batch_size = 32
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Transformer Model with Positional Encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=16):
        super(PositionalEncoding, self).__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.pe = pe.unsqueeze(0)

    def forward(self, x):
        return x + self.pe.to(x.device)

class MidiTransformer(nn.Module):
    def __init__(self, input_dim=4, model_dim=128, num_heads=4, num_layers=3, ff_dim=512):
        super(MidiTransformer, self).__init__()
        self.embedding = nn.Linear(input_dim, model_dim)
        self.pos_encoder = PositionalEncoding(model_dim)
        encoder_layers = nn.TransformerEncoderLayer(
            d_model=model_dim, nhead=num_heads, dim_feedforward=ff_dim
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers)
        self.fc = nn.Linear(model_dim, input_dim)  

    def forward(self, x):
        x = self.embedding(x)  
        x = self.pos_encoder(x)  
        x = self.transformer_encoder(x)  
        x = self.fc(x)  
        return x


In [170]:
# Model, loss function, and optimizer
model = MidiTransformer()
criterion = nn.MSELoss()
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop with batches
num_epochs = 100
for epoch in range(num_epochs):
    total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        outputs = model(batch)  
        loss = criterion(outputs, batch)  
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {total_loss / len(train_loader)}")




Epoch 0, Loss: 39.35063195228577
Epoch 10, Loss: 19.882048726081848
Epoch 20, Loss: 3.9847057461738586
Epoch 30, Loss: 2.1883397921919823
Epoch 40, Loss: 1.5797059908509254
Epoch 50, Loss: 1.2809944078326225
Epoch 60, Loss: 1.150880753993988
Epoch 70, Loss: 1.0622083991765976
Epoch 80, Loss: 1.0230532139539719
Epoch 90, Loss: 0.9688814021646976


In [171]:
# Generate a new MIDI sequence from a random input

random_index = np.random.randint(0, len(dataset))
random_sequence = dataset[random_index].unsqueeze(0)

print(random_sequence) 


tensor([[[7.4000e+01, 8.7000e+01, 4.7186e+01, 9.5052e-02],
         [7.6000e+01, 7.6000e+01, 4.7268e+01, 9.6354e-02],
         [6.2000e+01, 8.1000e+01, 4.7331e+01, 5.4688e-02],
         [7.8000e+01, 9.9000e+01, 4.7339e+01, 7.0312e-02],
         [6.4000e+01, 8.0000e+01, 4.7579e+01, 5.7292e-02],
         [8.0000e+01, 8.7000e+01, 4.7579e+01, 9.5052e-02],
         [6.6000e+01, 8.9000e+01, 4.7810e+01, 7.0312e-02],
         [8.1000e+01, 8.3000e+01, 4.7809e+01, 1.2370e-01],
         [8.0000e+01, 7.8000e+01, 4.7909e+01, 1.2109e-01],
         [7.8000e+01, 8.2000e+01, 4.8000e+01, 1.6146e-01],
         [7.6000e+01, 7.9000e+01, 4.8104e+01, 9.2448e-02],
         [6.2000e+01, 8.0000e+01, 4.8208e+01, 8.2031e-02],
         [7.4000e+01, 8.3000e+01, 4.8208e+01, 1.1719e-01],
         [7.3000e+01, 8.3000e+01, 4.8324e+01, 1.2760e-01],
         [7.1000e+01, 7.9000e+01, 4.8431e+01, 9.1146e-02],
         [6.9000e+01, 7.6000e+01, 4.8535e+01, 1.5365e-01]]])


In [172]:
generated_sequence = model(random_sequence).detach().numpy()
generated_sequence

array([[[69.67848   , 77.24085   , 43.734467  ,  0.08995359],
        [71.08682   , 70.62157   , 43.540554  ,  0.0734554 ],
        [62.674156  , 77.625496  , 45.823338  ,  0.03369004],
        [69.77228   , 85.28928   , 41.819225  ,  0.06462853],
        [64.100845  , 78.56765   , 47.167805  , -0.00189007],
        [72.8452    , 75.545845  , 42.21764   ,  0.09657156],
        [63.909595  , 80.037926  , 44.594505  ,  0.03622627],
        [70.967255  , 71.6163    , 40.73186   ,  0.09243037],
        [71.590485  , 68.77719   , 42.228886  ,  0.13451473],
        [70.71718   , 71.53368   , 42.364616  , -0.01942845],
        [70.60805   , 71.46923   , 42.957943  ,  0.08632839],
        [63.822342  , 74.28846   , 46.749096  ,  0.08174405],
        [69.42537   , 76.87594   , 44.60462   ,  0.03748   ],
        [69.20035   , 76.25457   , 45.10779   ,  0.04036448],
        [68.52671   , 73.507225  , 45.405888  , -0.0301188 ],
        [67.16599   , 70.19746   , 45.853695  ,  0.04409356]]],
      

In [175]:
# Convert generated sequence to MIDI
def sequence_to_midi(sequence, output_file="generated.mid"):
    midi = pretty_midi.PrettyMIDI()
    instrument = pretty_midi.Instrument(program=0)  

    # Adjust the start times to ensure the sequence starts at 0
    min_start_time = min(note_data[2] for note_data in sequence)
    
    for note_data in sequence:
        pitch, velocity = map(int, note_data[:2])
        start, duration = map(float, note_data[2:])
        start -= min_start_time  # Adjust start time to ensure it starts at 0
        note = pretty_midi.Note(
            velocity=max(0, min(127, velocity)),  
            pitch=max(0, min(127, pitch)),  
            start=start,
            end=start + duration
        )
        instrument.notes.append(note)

    midi.instruments.append(instrument)
    midi.write(output_file)
    print(instrument.notes)
    print(f"Generated MIDI saved as {output_file}")
sequence_to_midi(generated_sequence[0], "generated.mid")


[Note(start=3.002605, end=3.092559, pitch=69, velocity=77), Note(start=2.808693, end=2.882148, pitch=71, velocity=70), Note(start=5.091476, end=5.125166, pitch=62, velocity=77), Note(start=1.087364, end=1.151993, pitch=69, velocity=85), Note(start=6.435944, end=6.434053, pitch=64, velocity=78), Note(start=1.485779, end=1.582350, pitch=72, velocity=75), Note(start=3.862644, end=3.898870, pitch=63, velocity=80), Note(start=0.000000, end=0.092430, pitch=70, velocity=71), Note(start=1.497025, end=1.631539, pitch=71, velocity=68), Note(start=1.632755, end=1.613327, pitch=70, velocity=71), Note(start=2.226082, end=2.312410, pitch=70, velocity=71), Note(start=6.017235, end=6.098979, pitch=63, velocity=74), Note(start=3.872761, end=3.910241, pitch=69, velocity=76), Note(start=4.375931, end=4.416295, pitch=69, velocity=76), Note(start=4.674026, end=4.643908, pitch=68, velocity=73), Note(start=5.121834, end=5.165927, pitch=67, velocity=70)]
Generated MIDI saved as generated.mid
