<a href="https://www.kaggle.com/code/sharabhojha/chord-generation-lstm-example?scriptVersionId=222966548" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [3]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [4]:
# Preprocessing steps:
# 1: condense onset files to just onset times and onset notes

import os
import re
import pandas as pd

# Directory path to your annotations folder
annotations_dir = "/kaggle/input/aam-annotations/AAM-annotations/"
num = 0

# Iterate through all files in the directory
for filename in os.listdir(annotations_dir):
    if "onsets" in filename and filename.endswith(".arff"):  # Ensure it's an ARFF file with 'onsets' in its name
        num += 1
        file_path = os.path.join(annotations_dir, filename)
        
        # Read the ARFF file
        with open(file_path, "r", encoding="utf-8") as file:
            lines = file.readlines()

        attributes = []
        data_start = False
        data_rows = []

        for line in lines:
            line = line.strip()
            
            # Ignore comments and empty lines
            if not line or line.startswith("%"):
                continue

            if line.lower().startswith("@attribute"):
                # Extract attribute name (between the first space and the last space)
                match = re.match(r"@attribute\s+['\"]?([\w\s]+)['\"]?\s+.*", line, re.IGNORECASE)
                if match:
                    attributes.append(match.group(1).strip())

            elif line.lower().startswith("@data"):
                data_start = True  # Data section starts
            
            elif data_start:
                # Split considering quoted strings properly
                values = re.findall(r'\".*?\"|\'.*?\'|[^,]+', line)
                values = [v.strip("\"' ") if v.strip() else None for v in values]  # Remove extra quotes
                data_rows.append(values)

        # Convert to DataFrame
        df = pd.DataFrame(data_rows, columns=attributes)

        # Convert numeric columns where possible
        for col in df.columns:
            try:
                df[col] = pd.to_numeric(df[col])  # Convert if possible
            except ValueError:
                pass  # Keep as string if conversion fails

        all_onsets = []

        # collect all onset events per timestamp
        for row in range(df.index.size):
            notes_at_onset = []
            for col in range(1, df.columns.size):
                notes_at_onset.append(df.iat[row, col])
            notes_at_onset = re.findall("(\d+)", ''.join(notes_at_onset))
            all_onsets.append([int(item) for item in notes_at_onset])

        # delete onset events for individual instruments and add column for all events
        allCols = df.columns[df.apply(lambda col: col.astype(str).str.contains(r"\[", regex=True)).any()].tolist()
        df.drop(allCols, axis=1, inplace=True)
        df["Onset events"] = all_onsets

        # Save the processed DataFrame to a new CSV file
        output_file = re.search("(\d+)", filename).group(0) + "_onset_condensed.csv"
        df.to_csv(output_file, index=False)

        if num % 100 == 0:
            print(f"Processed {filename} and saved to {output_file}")

Processed 0521_onsets.arff and saved to 0521_onset_condensed.csv
Processed 2699_onsets.arff and saved to 2699_onset_condensed.csv
Processed 2634_onsets.arff and saved to 2634_onset_condensed.csv
Processed 0544_onsets.arff and saved to 0544_onset_condensed.csv
Processed 1658_onsets.arff and saved to 1658_onset_condensed.csv
Processed 1624_onsets.arff and saved to 1624_onset_condensed.csv
Processed 1192_onsets.arff and saved to 1192_onset_condensed.csv
Processed 0001_onsets.arff and saved to 0001_onset_condensed.csv
Processed 0382_onsets.arff and saved to 0382_onset_condensed.csv
Processed 1000_onsets.arff and saved to 1000_onset_condensed.csv
Processed 0541_onsets.arff and saved to 0541_onset_condensed.csv
Processed 1630_onsets.arff and saved to 1630_onset_condensed.csv
Processed 0696_onsets.arff and saved to 0696_onset_condensed.csv
Processed 0121_onsets.arff and saved to 0121_onset_condensed.csv
Processed 0097_onsets.arff and saved to 0097_onset_condensed.csv
Processed 0220_onsets.arf

In [5]:
# 2: encode chord names and replace said chord names with encodings in beatinfo files

# Directory path to your annotations folder
headers = ['Start time in seconds', 'Bar count', 'Quarter count', 'Chord name']

chords = set()

dataframes = []
filenames = []

# Iterate through all files in the directory
for filename in os.listdir(annotations_dir):
    if "beatinfo" in filename and filename.endswith(".arff"):  # Ensure it's an ARFF file with 'beatinfo' in its name
        file_path = os.path.join(annotations_dir, filename)
        filenames.append(filename)
        df = pd.read_csv(file_path, comment='@', header=None)
        df.columns = headers

        for i in range(df.index.size):
            df.iat[i, 3] = df.iat[i, 3].replace("'", "")
            if df.iat[i, 3] == "BASS_NOTE_EXCEPTION":
                df.iat[i, 3] = "N.C."
            chords.add(df.iat[i, 3])
    
        dataframes.append(df)

sorted_chords = sorted(list(chords))
chord_encodings = dict(zip([i for i in range(len(chords))], sorted_chords))
print(chord_encodings)

# Modify existing dataframes to match encodings
for i in range(len(dataframes)):
    dataframes[i].drop(columns=['Bar count', 'Quarter count'], inplace=True)

    for j in range(dataframes[i].index.size):
        
        dataframes[i].iat[j, 1] = sorted_chords.index(dataframes[i].iat[j, 1])

    dataframes[i].to_csv(filenames[i].replace('arff', 'csv'), index=False)

{0: 'A#maj', 1: 'A#min', 2: 'Amaj', 3: 'Amin', 4: 'Bmaj', 5: 'Bmin', 6: 'C#maj', 7: 'C#min', 8: 'Cmaj', 9: 'Cmin', 10: 'D#maj', 11: 'D#min', 12: 'Dmaj', 13: 'Dmin', 14: 'Emaj', 15: 'Emin', 16: 'F#maj', 17: 'F#min', 18: 'Fmaj', 19: 'Fmin', 20: 'G#maj', 21: 'G#min', 22: 'Gmaj', 23: 'Gmin', 24: 'N.C.'}


In [6]:
# visualize the files
working_dir = "/kaggle/working/"
onsets = pd.read_csv(working_dir + "0001_onset_condensed.csv")
print(onsets.head())
beatinfo = pd.read_csv(working_dir + "0001_beatinfo.csv")
print(beatinfo.head())

def align_onsets_with_chords(onsets, beatinfo):
    aligned_data = []
    for _, onset_row in onsets.iterrows():
        onset_time = onset_row['Onset time in seconds']
        # Find the chord corresponding to this onset time
        chord_row = beatinfo[beatinfo['Start time in seconds'] <= onset_time].iloc[-1]
        onset_list = eval(onset_row['Onset events'])
        aligned_data.append((onset_list, chord_row['Chord name']))
    return aligned_data

   Onset time in seconds  Onset events
0               0.000000  [41, 60, 65]
1               0.326086      [41, 60]
2               0.652173  [41, 65, 65]
3               0.978259  [41, 65, 69]
4               1.304346  [41, 65, 65]
   Start time in seconds  Chord name
0               0.000000          18
1               0.652174          18
2               1.304348          18
3               1.956522          18
4               2.608696           0


In [7]:
# create aligned data for every onset and beatinfo file

all_data = []

for filename in os.listdir(working_dir):
    if "onset" in filename:
        onset_path = os.path.join(working_dir, filename)
        beatinfo_path = os.path.join(working_dir, re.search("(\d+)", filename).group(0) + "_beatinfo.csv")
        onsets = pd.read_csv(onset_path)
        beatinfo = pd.read_csv(beatinfo_path)
        all_data += align_onsets_with_chords(onsets, beatinfo)

print(all_data[0:20])

[([70, 36, 42, 39, 58, 63, 66, 61], 11.0), ([63, 42, 39, 58, 63, 66, 61], 11.0), ([70, 40, 42, 39, 58, 63, 66, 59], 11.0), ([66, 36, 42, 39, 58, 63, 66], 11.0), ([70, 36, 42, 39, 58, 63, 66, 59], 11.0), ([63, 42, 39, 58, 63, 66, 65], 11.0), ([66, 36, 40, 42, 39, 58, 63, 66, 59], 11.0), ([63, 42, 39, 58, 63, 66, 59], 11.0), ([65, 36, 42, 46, 58, 61, 65, 59], 1.0), ([70, 42, 46, 58, 61, 65], 1.0), ([65, 40, 42, 46, 58, 61, 65, 59], 1.0), ([61, 36, 42, 46, 58, 61, 65], 1.0), ([65, 36, 42, 46, 58, 61, 65, 61], 1.0), ([70, 42, 46, 58, 61, 65, 61], 1.0), ([65, 36, 40, 42, 46, 58, 61, 65, 59], 1.0), ([61, 42, 46, 58, 61, 65], 1.0), ([68, 36, 42, 37, 56, 61, 65, 59], 6.0), ([61, 42, 37, 56, 61, 65, 65], 6.0), ([68, 40, 42, 37, 56, 61, 65, 59], 6.0), ([65, 36, 42, 37, 56, 61, 65], 6.0)]


In [11]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Embedding, LSTM, Dense, TimeDistributed, BatchNormalization
from tensorflow.keras.regularizers import l2
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.optimizers import Adam

def create_chord_classification_model(vocab_size, embedding_dim, lstm_units, num_classes, max_sequence_length):
    # Input for note sequences
    note_input = Input(shape=(max_sequence_length,))
    
    # Embedding layer for note sequences
    note_embedding = Embedding(vocab_size, embedding_dim)(note_input)
    
    # LSTM layers
    lstm_output = LSTM(lstm_units, return_sequences=False, dropout=0.2, recurrent_dropout=0.2)(note_embedding) # with more lstm layers, return sequences = true
    lstm_output = BatchNormalization()(lstm_output)
    #lstm_output = LSTM(lstm_units, return_sequences=False, dropout=0.2, recurrent_dropout=0.2)(lstm_output)
    
    # Output layer
    output = Dense(num_classes, activation='softmax', kernel_regularizer=l2(0.01))(lstm_output)
    
    model = Model(inputs=note_input, outputs=output)
    return model

# Hyperparameters
vocab_size = 128  # Assuming MIDI note range
embedding_dim = 32
lstm_units = 64
num_classes = 25  # Number of chord classes
max_sequence_length = 4  # Adjust based on your data

# Create the model
model = create_chord_classification_model(vocab_size, embedding_dim, lstm_units, num_classes, max_sequence_length)

# Compile the model
optimizer = Adam(learning_rate=0.0001)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

# Function to prepare data
def prepare_data(data, max_sequence_length):
    X = []
    y = []
    for sequence in data:
        notes, chord = sequence
        padded_notes = tf.keras.preprocessing.sequence.pad_sequences([notes], maxlen=max_sequence_length, padding='post', truncating='post')[0]
        X.append(padded_notes)
        y.append(chord)
    return np.array(X), np.array(y)

# Prepare your data
X, y = prepare_data(all_data, max_sequence_length)

# Convert y to one-hot encoded format
y_onehot = tf.keras.utils.to_categorical(y, num_classes=num_classes)

# Train the model
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
history = model.fit(X, y_onehot, validation_split=0.2, epochs=100, batch_size=32, callbacks=[early_stopping])

# Function for inference
def predict_chord(model, note_sequence):
    padded_sequence = tf.keras.preprocessing.sequence.pad_sequences([note_sequence], maxlen=max_sequence_length, padding='post', truncating='post')
    predictions = model.predict(padded_sequence)
    return np.argmax(predictions[0])  # Return the prediction

# Example usage
sample_sequence = [60, 64, 67, 72]  # C major chor
predicted_chord = predict_chord(model, sample_sequence)
print(f"Predicted chord num: {predicted_chord}")
print(f"Predicted chord: {chord_encodings[predicted_chord]}")

Epoch 1/100
[1m42808/42808[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m192s[0m 4ms/step - accuracy: 0.6353 - loss: 1.5936 - val_accuracy: 0.6959 - val_loss: 1.1297
Epoch 2/100
[1m42808/42808[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m192s[0m 4ms/step - accuracy: 0.6884 - loss: 1.1828 - val_accuracy: 0.7060 - val_loss: 1.0550
Epoch 3/100
[1m42808/42808[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m187s[0m 4ms/step - accuracy: 0.6984 - loss: 1.1077 - val_accuracy: 0.7113 - val_loss: 1.0205
Epoch 4/100
[1m15103/42808[0m [32m━━━━━━━[0m[37m━━━━━━━━━━━━━[0m [1m1:47[0m 4ms/step - accuracy: 0.7037 - loss: 1.0659

KeyboardInterrupt: 