# Generate chords using the MIDI-Chord-Gen Model

- MIDI-Chord-Gen is a transformer based generative model.
- It takes a primer midi file to generate the rest of the chord progression.
- It has a large and diverse dataset of midi files to learn from.

**This particular Notebook uses a pretrained model's weights to generate chords**
>*note: that this is not a fine-tuned model so it will generate chord progressions that tend to hallucinate.*

# Setup Environment and Data Structure

## Setup Environment

In [1]:
from google.colab import drive
drive.mount('/content/drive')
proj_folder_path = '/content/drive/MyDrive/Colab Notebooks/MIDI-gen-notebooks/'
FULL_MIDI_LANG_DATA_path = '/content/drive/MyDrive/Colab Notebooks/MIDI-gen-notebooks/Large_MIDI_Language_Base.parquet'

Mounted at /content/drive


In [2]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from music21 import chord, note, stream, clef, meter,converter, midi
import os
from tqdm import tqdm
import datetime
from IPython.display import clear_output, Image, display
import PIL
import pandas as pd
import re
import fractions as fract
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

## Structure the Data

In [3]:
#@title Open MIDI File
def open_midi(file_path):
    try:
        midi = converter.parse(file_path)
        return midi
    except Exception as e:
        print(f"An error occurred while opening the MIDI file: {e}")

In [4]:
#@title Filter vocabulary
def filter_data(input_list, prev_index):
    filtered_list = []
    part_start_flag = False

    for element in input_list[prev_index:-1]:
        if element == '<part_start>':
            part_start_flag = True
        elif element == '<chord_meta>' and part_start_flag:
            filtered_list.append('<part_start>')
            filtered_list.append(element)
            part_start_flag = False
        elif part_start_flag is False:
            filtered_list.append(element)

    return filtered_list

In [5]:
#@title Create Vocabulary for primer
def create_primer_vocab(midi_data, genre):
    prev_index =0
    vocab =[]
    vocab.append("<song_start>")
    vocab.append(f"<{genre}>")
    # if _artist_name is not None:
    #     vocab.append(f"<artist {_artist_name}>")
    # vocab.append(f"<song {_song_name}>")
    if midi_data is not None:
        for i, part in enumerate(midi_data.parts):
            vocab.append(f"<part_start>")
            for element in part.recurse():
                # if isinstance(element, meter.TimeSignature):
                #     vocab.append(element.ratioString)
                if isinstance(element, chord.Chord):
                    vocab.append("<chord_meta>")
                    vocab.append(
                        f"chord_quarterlength {element.duration.quarterLength.real}")
                    vocab.append(
                        f"chord_offset {element.offset}")
                    vocab.append("<chord_start>")
                    for note_i, n in enumerate(element):
                        vocab.append(f"note_{note_i}")
                        vocab.append(
                            f"note_pitch value_{n.pitch.midi}")
                        vocab.append(
                            f"note_velocity value_{n.volume.velocity}")
                        vocab.append(
                            f"note_quarterlength value_{n.duration.quarterLength}")
                        vocab.append(
                            f"note_offset value_{n.offset}")
                    vocab.append("<chord_end>")
        vocab.append("<song_end>")
        # vocab[prev_index:-
        #               1] = filter_data(vocab, prev_index)
        # prev_index = len(vocab)
    return vocab

# The Model

In [6]:
FULL_MIDI_LANG_DATA_list = pd.read_parquet(FULL_MIDI_LANG_DATA_path)["full_MIDI_lang_base"].values.tolist()

In [7]:
len(FULL_MIDI_LANG_DATA_list)

20940575

In [8]:
#@title Hyper Parameters
unique_chars = sorted(list(set(FULL_MIDI_LANG_DATA_list)))

vocab_size = len(unique_chars)
block_size = 128
max_length = 128
batch_size = 32


n_embd = 512
n_heads = 8
forward_expansion = 4 # multiple with n_embd
num_layers = 6
dim_feedforward = n_embd*forward_expansion

learning_rate = 0.0001

dropout = 0.4

epochs = 100
max_iters = 500
eval_interval = 50
eval_iters = 50

device = 'cuda' if torch.cuda.is_available() else 'cpu'
attention_weights = []

In [9]:
#@title Dictionary, decoder and encoder of the main vocabulary
stoi = {ch: i for i, ch in enumerate(unique_chars)}
itos = {i: ch for i, ch in enumerate(unique_chars)}
def encode(s): return [stoi[c] for c in s]
def decode(l): return '\n'.join([itos[i] for i in l])

In [10]:
#@title Get Batch Function

MIDI_DATA_TENSOR = torch.tensor(encode(FULL_MIDI_LANG_DATA_list), dtype=torch.long)

n = int(0.9*len(MIDI_DATA_TENSOR))  # first 90% will be train, rest val
train_data = MIDI_DATA_TENSOR[:n]
val_data = MIDI_DATA_TENSOR[n:]
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

In [11]:
#@title Self Attention class
class SelfAttention(nn.Module):
    def __init__(self):
        super(SelfAttention, self).__init__()
        self.embed_size = n_embd
        self.heads = n_heads
        self.head_dim = n_embd // n_heads

        assert (
            self.head_dim * n_heads == n_embd
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(n_heads * self.head_dim, n_embd)

    def forward(self, values, keys, query, mask=None):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into self.heads different pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)
        # values, keys, queries: (N, S, H, E/H)        (32, 20, 8, 8)


        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        # energy: (N, H, S, S)                         (32, 8, 20, 20)

        if mask is not None:
            energy = energy.masked_fill(mask == 0, float('-inf'))



        attention = torch.softmax(energy / (self.embed_size ** 0.5), dim=3)
        # attention: (N, H, S, S)                      (32, 8, 20, 20)


        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )
        # out: (N, S, E)    (32, 20, 64)


        out = self.fc_out(out)
        # out: (N, S, E)    (32, 20, 64)
        return out


In [12]:
#@title Decoderblock Class
class DecoderBlock(nn.Module):
    def __init__(self):
        super(DecoderBlock, self).__init__()
        self.attention = SelfAttention()
        self.norm1 = nn.LayerNorm(n_embd)
        self.norm2 = nn.LayerNorm(n_embd)

        self.feed_forward = nn.Sequential(
            nn.Linear(n_embd, forward_expansion * n_embd),
            # self.feed_forward[0](x): (N, S, E*F)  (32, 20, 2048)
            nn.ReLU(),
            # self.feed_forward[2](self.feed_forward[0](x)): (32, 20, 64)
            nn.Linear(forward_expansion * n_embd, n_embd),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask=None):
        # query, key, value: (N, S, E)       (32, 20, 64)

        attention = self.attention(value, key, query, mask)
        # attention: (N, S, E)               (32, 20, 64)

        x = self.dropout(self.norm1(attention + query))
        # x: (N, S, E)                       (32, 20, 64)

        forward = self.feed_forward(x) # -> runs a sequential class
        # forward: (N, S, E)                 (32, 20, 64)

        out = self.dropout(self.norm2(forward + x))
        # out: (N, S, E)                     (32, 20, 64)
        return out


In [13]:
#@title MIDI Chord Gen Class
class MIDIGenModel(nn.Module):
    def __init__(self):
        super(MIDIGenModel, self).__init__()

        self.embed_size = n_embd
        self.device = device

        self.word_embedding = nn.Embedding(vocab_size, n_embd)
        # self.word_embedding(x): (N, S, E)             (32, 20, 64)

        self.position_embedding = nn.Embedding(max_length, n_embd)
        # self.position_embedding(positions): (N, S, E) (32, 20, 64)


        self.layers = nn.ModuleList(
            [
                DecoderBlock()
                for _ in range(num_layers)
            ]
        )

        self.fc_out = nn.Linear(n_embd, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, targets=None):
        # x: (N, S)                                     (32, 20)
        N, seq_length = x.shape


        positions = (
            torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        )

        out = self.dropout(
            (self.word_embedding(x) + self.position_embedding(positions))
        )
        # out: (N, S, E)                                (32, 20, 64)
        mask = torch.tril(torch.ones((seq_length, seq_length))
               .type(torch.BoolTensor)
               ).to(self.device)

        for layer in self.layers:
            # each transformer block
            out = layer(out, out, out, mask) # out: (N, S, E) (32, 20, 64)

        out = self.fc_out(out) # (32, 20, 64)

        if targets is None:
            loss = None

        else:
            B, T, C = out.shape
            logits = out.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return out, loss

    def generate(self, idx, max_new_tokens, p=0.9):
        probs_list=[]
    # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)

            # Apply top-p nucleus sampling
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove tokens with cumulative probability above the threshold
            sorted_indices_to_remove = cumulative_probs > p
            # Shift the indices to the right to keep also the first token above the threshold
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0

            indices_to_remove = sorted_indices[sorted_indices_to_remove]
            logits[:, indices_to_remove] = float('-inf')

            # Sample from the remaining distribution
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            probs_list.append(probs.cpu().detach().numpy())
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx, probs_list




In [14]:
#@title Loss Function
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()

    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

# Instatiate the Model

In [15]:
#@title Load Pretrained Weights
model_weights_path = "/content/drive/MyDrive/Colab Notebooks/MIDI-gen-notebooks/07-12-23_23:06_model_weights.pth"
model_gen = MIDIGenModel()
model_gen.load_state_dict(torch.load(model_weights_path))
model_gen = model_gen.to(device)

# Upload the Primer

In [41]:
#@title Run This block to upload file
from google.colab import files

print("Please upload a MIDI file:")

uploaded = files.upload()

for fn in uploaded.keys():
    midi_file = fn
# Process the file content here

Please upload a MIDI file:


Saving sample 2 midid.mid to sample 2 midid.mid


In [57]:
#@title Convert the primer to integer form
midi_data = open_midi(midi_file)
primer = [stoi[midi_word] for midi_word in create_primer_vocab(midi_data, "rock")]

In [58]:
#@title Convert the primer into a tensor object and send it to GPU
context = torch.tensor([primer],dtype=torch.long, device=device)

# Generate

In [59]:
#@markdown The chord progression length doesn't translate directly to the output because the length takes into account other tokens too.

#@markdown For example, "`[<note start>, <chord_meta_data>, <chord_end>, <chord_start>]` "

#@markdown So an approximate ratio would be 20 tokens to 1 chord
chord_progression_length = 500 #@param {type:"slider", min:40, max:1000, step:10}
p_amount = 0.95 #@param {type:"slider", min:0.85, max:0.95, step:0.01}
generated_out1, probs_1 = model_gen.generate(context, chord_progression_length,p_amount)
# generated_out2,probs_2  =model_gen.generate(context, chord_progression_length,p_amount)
# generated_out3, probs_3 =model_gen.generate(context, chord_progression_length,p_amount)

In [62]:
#@title plot probs
probs_1_0 = probs_1[-2].cpu().detach().numpy()
plt.figure(figsize=(12, 6))
plt.plot(probs_1_0)

plt.xlabel('Vocabulary ID')
plt.ylabel('Probability')
plt.show()

IndexError: ignored

In [63]:
generated_out1=generated_out1[0].tolist()

In [64]:
#@title Make MIDI
def make_midi(generated_chords:list):
    note_index_pattern = r'note_(\d+)'
    value_pattern = r'value_([\d\.]+)'

    chord_meta_flag =False
    chord_meta_index = None

    chord_num = 0
    chord_dict = {}

    chord_start_flag = False
    chord_start_index = None

    note_flag = False
    note_num = None

    for i, token in enumerate(generated_chords):
        if itos[token] == "<chord_meta>" and chord_meta_flag == False and chord_meta_index==None:

            chord_meta_flag=True
            chord_meta_index=i
            chord_num+=1
            chord_dict[chord_num] = {}

        elif itos[token] == "<chord_start>" and chord_meta_flag == False and chord_meta_index == None and chord_start_flag ==False and chord_start_index == None:
            chord_start_flag = True
            chord_start_index = i

            chord_meta_flag=True
            chord_meta_index=i
            chord_num+=1
            chord_dict[chord_num] = {}



        if itos[token].startswith("chord_quarterlength") and chord_meta_flag == True and chord_meta_index is not None and i > chord_meta_index:
            # chord_obj.duration.quarterLength = float(itos[token].split()[-1])
            chord_dict[chord_num]["chord_quarterlength"] = float(fract.Fraction(itos[token].split()[-1]))

        if itos[token].startswith("chord_offset") and chord_meta_flag == True and chord_meta_index is not None:

            chord_dict[chord_num]["chord_offset"]=float(fract.Fraction(itos[token].split()[-1]))

        if itos[token] == "<chord_start>" and chord_meta_flag == True and chord_meta_index is not None:
            chord_start_flag = True
            chord_start_index = i
            # chord_num+=1

        if re.match(note_index_pattern, itos[token]) and chord_start_flag == True and chord_start_index is not None:
            note_num = int(itos[token][5:])
            chord_dict[chord_num][note_num] ={}

        if itos[token].startswith("note_pitch") and chord_start_flag == True and chord_start_index is not None and note_num is not None:
            res = re.search(value_pattern, itos[token])
            if res:
                chord_dict[chord_num][note_num]["note_pitch"]=int(res.group(1))

        if itos[token].startswith("note_velocity") and chord_start_flag == True and chord_start_index is not None and note_num is not None:
            res = re.search(value_pattern, itos[token])
            if res:
                chord_dict[chord_num][note_num]["note_velocity"]=int(res.group(1))

        if itos[token].startswith("note_quarterlength") and chord_start_flag == True and chord_start_index is not None and note_num is not None:
            res = re.search(value_pattern, itos[token])
            if res:
                chord_dict[chord_num][note_num]["note_quarterlength"]=float(fract.Fraction(res.group(1)))

        if itos[token].startswith("note_offset") and chord_start_flag == True and chord_start_index is not None and note_num is not None:
            res = re.search(value_pattern, itos[token])
            if res:
                chord_dict[chord_num][note_num]["note_offset"]=float(fract.Fraction(res.group(1)))

        if itos[token] == "<chord_end>" and chord_meta_flag == True and chord_meta_index is not None and chord_start_flag ==True and chord_start_index is not None and note_num is not None:
            chord_meta_flag=False
            chord_meta_index=None

            chord_start_flag=False
            chord_start_index=None
            note_num=None

    return chord_dict

In [65]:
mdi_dict1 = make_midi(generated_out1)
#mdi_dict2,mdi_dict3=make_midi(generated_out2),make_midi(generated_out3)

In [66]:
from pprint import pprint
pprint(mdi_dict1)

{1: {0: {'note_offset': 0.0,
         'note_pitch': 62,
         'note_quarterlength': 1.0,
         'note_velocity': 100},
     1: {'note_offset': 0.0,
         'note_pitch': 67,
         'note_quarterlength': 1.0,
         'note_velocity': 100},
     2: {'note_offset': 0.0,
         'note_pitch': 71,
         'note_quarterlength': 1.0,
         'note_velocity': 100},
     'chord_offset': 0.0,
     'chord_quarterlength': 1.0},
 2: {0: {'note_offset': 0.0,
         'note_pitch': 62,
         'note_quarterlength': 1.0,
         'note_velocity': 100},
     1: {'note_offset': 0.0,
         'note_pitch': 66,
         'note_quarterlength': 1.0,
         'note_velocity': 100},
     2: {'note_offset': 0.0,
         'note_pitch': 69,
         'note_quarterlength': 1.0,
         'note_velocity': 100},
     'chord_offset': 1.0,
     'chord_quarterlength': 1.0},
 3: {0: {'note_offset': 0.0,
         'note_pitch': 60,
         'note_quarterlength': 1.0,
         'note_velocity': 100},
     1: {'no

In [67]:
s = stream.Stream()
for chord_ in mdi_dict1.items():
    c = chord.Chord()
    chord_meta = {}
    notes_list = []
    for meta_data_ in chord_[1].items():
        n = note.Note()
        if isinstance(meta_data_[0], int):
            if "note_pitch" in meta_data_[1].keys():
                n.pitch.midi= int(meta_data_[1]["note_pitch"])
                try:
                    n.volume.velocity = meta_data_[1]["note_velocity"]
                    n.duration.quarterLength =  float(fract.Fraction(meta_data_[1]["note_quarterlength"]))
                    n.offset = float(fract.Fraction(meta_data_[1]["note_quarterlength"]))
                except KeyError:
                    pass
        else:
            if meta_data_[0] == "chord_quarterlength":
                chord_meta["chord_quarterlength"] = float(fract.Fraction(meta_data_[1]))
            elif meta_data_[0] == "chord_offset":
                chord_meta["chord_offset"] = float(fract.Fraction(meta_data_[1]))

        notes_list.append(n)

    if len(notes_list) >= 2:
        # print([n.pitch for n in notes_list])
        c.add(notes_list)
        try:
            c.duration.quarterLength = chord_meta["chord_quarterlength"]
            c.offset = chord_meta["chord_offset"]
        except KeyError:
            pass
        s.append(c)
    else:
        continue
s.write('midi', fp='F_major_primer_draft3.mid')


'F_major_primer_draft3.mid'