### Import the necessary libraries 

In [1]:
import numpy as np 
import pandas as pd


In [2]:
# Load one chorale file from the training folder so its layout can be inspected.
df = pd.read_csv('/kaggle/input/bach-chorales-2/train/chorale_000.csv')

In [3]:
# Display the loaded frame (bare last expression -> rich notebook output).
df

Unnamed: 0,note0,note1,note2,note3
0,74,70,65,58
1,74,70,65,58
2,74,70,65,58
3,74,70,65,58
4,75,70,58,55
...,...,...,...,...
187,70,65,62,46
188,70,65,62,46
189,70,65,62,46
190,70,65,62,46


In [4]:
import os 
# Root folder of the dataset; every split sits in its own sub-directory.
data_dir = '/kaggle/input/bach-chorales-2'

train_dir, test_dir, valid_dir = (
    os.path.join(data_dir, split_name) for split_name in ('train', 'test', 'valid')
)


In [5]:
# Collect the full paths of the .csv files in every split directory.
# Sorting the joined paths gives each list a stable, reproducible order.
train_files, test_files, valid_files = (
    sorted(
        os.path.join(split_dir, file_name)
        for file_name in os.listdir(split_dir)
        if file_name.endswith('.csv')
    )
    for split_dir in (train_dir, test_dir, valid_dir)
)


In [6]:
# Read every file of a split; each chorale becomes a plain nested list
# (one inner list per CSV row).
train_data, test_data, valid_data = (
    [pd.read_csv(csv_path).values.tolist() for csv_path in csv_paths]
    for csv_paths in (train_files, test_files, valid_files)
)

In [7]:
from music21 import stream, chord 

# Listen to one training chorale: every row becomes one chord in a music21 stream.
chorale = train_data[20]

s = stream.Stream()
for row in chorale:
    # Zero entries are skipped so only non-zero pitches end up in the chord.
    # The duration keyword must be spelled `quarterLength`; the previous
    # `quarterLenght` was a typo and never set the duration explicitly.
    s.append(chord.Chord([n for n in row if n], quarterLength=1))

s.show('midi')

### Preprocessing 

In [8]:
# Lowest and highest pitch value kept in the vocabulary; 0 is handled separately
# and always maps to token 0.
min_note = 36
max_note = 81

# Window length and stride (both counted in chorale rows), plus the training batch size.
window_size = 32
window_offset = 16
batch_size = 32


def make_xy(chorales):
    """Turn chorales into flattened next-token training pairs.

    Each chorale is cut into overlapping windows of ``window_size + 1`` rows,
    starting every ``window_offset`` rows. Non-zero values are shifted so that
    ``min_note`` becomes token 1, zeros stay 0, and everything is clipped to
    ``[0, max_note - min_note + 1]``. Every window is then flattened row by
    row into a single sequence.

    Parameters
    ----------
    chorales : sequence of chorales
        Each chorale is a sequence of equally long rows of integers.

    Returns
    -------
    (inputs, targets) : tuple of numpy.ndarray
        ``inputs`` is every flattened window without its last token and
        ``targets`` is the same window without its first token, i.e. the
        inputs shifted one step into the future.
    """
    rows_per_window = window_size + 1

    windows = []
    for chorale in chorales:
        # Stopping at len - window_size guarantees that every slice is full length.
        for start in range(0, len(chorale) - window_size, window_offset):
            windows.append(chorale[start:start + rows_per_window])
    tokens = np.array(windows, dtype=int)

    highest_token = max_note - min_note + 1
    shifted = np.where(tokens == 0, 0, tokens - min_note + 1)
    tokens = np.clip(shifted, 0, highest_token)

    sequences = tokens.reshape(tokens.shape[0], -1)
    inputs = sequences[:, :-1]
    targets = sequences[:, 1:]
    return inputs, targets

# Apply the same windowing/tokenisation to each split, in train/test/valid order.
(X_train, y_train), (X_test, y_test), (X_valid, y_valid) = map(
    make_xy, (train_data, test_data, valid_data)
)
    

### Training the Model 

In [9]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv1D, Dense, Embedding, LSTM, Dropout, BatchNormalization
from tensorflow.keras.optimizers import Nadam




2025-10-07 19:41:12.219328: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1759866072.494791      13 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1759866072.575147      13 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [10]:
# (filters, dilation_rate) for each causal convolution. A dilation of 1 is the
# Conv1D default, so the first layer matches a call without dilation_rate.
conv_specs = [(32, 1), (48, 2), (64, 4), (96, 8), (128, 16)]

model = Sequential()

# 47 token ids (0..46) are embedded into 5-dimensional vectors; the sequence
# length is left open.
model.add(Embedding(input_dim=47, output_dim=5, input_shape=[None]))

# Stack of causal convolutions, each followed by batch normalisation.
for n_filters, dilation in conv_specs:
    model.add(
        Conv1D(
            n_filters,
            kernel_size=2,
            padding='causal',
            activation='relu',
            dilation_rate=dilation,
        )
    )
    model.add(BatchNormalization())

model.add(Dropout(0.05))
model.add(LSTM(256, return_sequences=True))

# One softmax distribution over the 47 token ids at every time step.
model.add(Dense(47, activation='softmax'))

model.summary()

  super().__init__(**kwargs)
2025-10-07 19:41:28.913355: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [11]:
# Targets are integer token ids, hence the sparse categorical cross-entropy loss.
optimizer = Nadam(learning_rate=1e-3)
model.compile(loss='sparse_categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])

# validation_data is passed as an (inputs, targets) tuple, the form documented
# for Model.fit; a list is not handled the same way by every Keras version.
model.fit(X_train, y_train, epochs=20, validation_data=(X_valid, y_valid), batch_size=batch_size)

Epoch 1/20
[1m98/98[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m50s[0m 385ms/step - accuracy: 0.3293 - loss: 2.6238 - val_accuracy: 0.0688 - val_loss: 3.6868
Epoch 2/20
[1m98/98[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m37s[0m 381ms/step - accuracy: 0.7645 - loss: 0.9056 - val_accuracy: 0.1043 - val_loss: 3.8363
Epoch 3/20
[1m98/98[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m37s[0m 382ms/step - accuracy: 0.7966 - loss: 0.7257 - val_accuracy: 0.1939 - val_loss: 3.2377
Epoch 4/20
[1m98/98[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m37s[0m 381ms/step - accuracy: 0.8127 - loss: 0.6425 - val_accuracy: 0.2783 - val_loss: 2.4717
Epoch 5/20
[1m98/98[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m37s[0m 381ms/step - accuracy: 0.8251 - loss: 0.5892 - val_accuracy: 0.5482 - val_loss: 1.4752
Epoch 6/20
[1m98/98[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m37s[0m 377ms/step - accuracy: 0.8356 - loss: 0.5483 - val_accuracy: 0.7690 - val_loss: 0.8086
Epoch 7/20
[1m98/98[

<keras.src.callbacks.history.History at 0x7ecc9eb1b4d0>

In [12]:
def sample_next_note(probs):
    """Draw one index at random, weighted by ``probs``.

    Parameters
    ----------
    probs : array-like of float
        One non-negative weight per candidate index. The weights do not have
        to sum to exactly 1; they are renormalised before sampling. The
        argument itself is never modified.

    Returns
    -------
    int
        The sampled index in ``range(len(probs))``. When the weights sum to
        zero, a negative number or a non-finite value, sampling is skipped and
        the index of the largest entry is returned instead.

    Raises
    ------
    ValueError
        If ``probs`` is empty, or (from ``np.random.choice``) if it contains
        negative entries while still having a positive, finite sum.
    """
    probabilities = np.asarray(probs, dtype=float)

    prob_sum = probabilities.sum()

    if prob_sum <= 0 or not np.isfinite(prob_sum):
        return int(np.argmax(probabilities))

    # Divide out of place: np.asarray hands back the caller's own array when it
    # is already float64, so an in-place `/=` would silently rescale their data.
    probabilities = probabilities / prob_sum

    # Cast so both return paths yield a plain Python int.
    return int(np.random.choice(len(probabilities), p=probabilities))

In [13]:
def generate_chorale(model, seed_chrods, length):
    """Extend a seed by sampling one note at a time from ``model``.

    Parameters
    ----------
    model : object with a ``predict(x, verbose=0)`` method
        Called with an integer array of shape ``(1, n_tokens)``. The entry
        ``[0, -1]`` of its result is used as the sampling weights for the
        next token.
    seed_chrods : sequence of rows of int
        Starting rows, given as pitch values (0 stays 0). The misspelt
        parameter name is kept so existing callers keep working.
    length : int
        Number of four-note rows to generate; four tokens are sampled per row.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_rows, 4)`` holding the seed followed by the
        generated rows, converted back from tokens to pitch values. Seed
        pitches outside ``[min_note, max_note]`` come back clipped.
    """
    notes_per_chord = 4

    token_sequence = np.array(seed_chrods, dtype=int)
    token_sequence = np.where(token_sequence == 0, 0, token_sequence - min_note + 1)
    # Apply the same clipping as make_xy so seed tokens can never fall outside
    # the id range the model was trained on.
    token_sequence = np.clip(token_sequence, 0, max_note - min_note + 1)
    token_sequence = token_sequence.reshape(1, -1)

    for _ in range(length * notes_per_chord):
        # Only the distribution predicted for the last position is needed.
        next_token_probabilities = model.predict(token_sequence, verbose=0)[0, -1]
        next_token = sample_next_note(next_token_probabilities)
        token_sequence = np.concatenate([token_sequence, [[next_token]]], axis=1)

    # Undo the token shift; token 0 stays 0.
    token_sequence = np.where(token_sequence == 0, 0, token_sequence + min_note - 1)

    return token_sequence.reshape(-1, notes_per_chord)
    

In [14]:
# Play one complete chorale from the test split.
seed_chords = test_data[2]
chorale = seed_chords

s = stream.Stream()
for chord_row in chorale:
    # filter(None, ...) drops the zero entries, keeping only non-zero pitches.
    sounding_pitches = list(filter(None, chord_row))
    s.append(chord.Chord(sounding_pitches, quarterLength=1))
s.show('midi')

In [15]:
# Use only the first 8 rows of that test chorale as the seed.
seed_chords = test_data[2][:8]

# generate_chorale samples length * 4 tokens, so 56 requests 56 new four-note rows after the seed.
new_chorale = generate_chorale(model, seed_chords, 56)

In [16]:
# Inspect the result: the seed rows come first, followed by the sampled rows.
new_chorale

array([[73, 68, 61, 53],
       [73, 68, 61, 53],
       [73, 68, 61, 53],
       [73, 68, 61, 53],
       [69, 66, 61, 54],
       [69, 66, 61, 54],
       [69, 66, 61, 54],
       [69, 66, 61, 54],
       [69, 66, 61, 54],
       [69, 66, 61, 54],
       [69, 66, 61, 54],
       [69, 66, 61, 54],
       [69, 66, 62, 50],
       [69, 66, 62, 50],
       [69, 66, 62, 51],
       [69, 66, 62, 51],
       [69, 65, 62, 50],
       [69, 65, 62, 50],
       [69, 66, 62, 48],
       [69, 66, 62, 48],
       [69, 64, 62, 46],
       [69, 64, 62, 46],
       [69, 62, 60, 45],
       [69, 62, 59, 45],
       [69, 62, 57, 45],
       [69, 62, 57, 45],
       [69, 64, 57, 45],
       [71, 64, 57, 45],
       [71, 64, 57, 45],
       [71, 64, 57, 45],
       [71, 64, 57, 52],
       [71, 64, 57, 52],
       [71, 64, 56, 52],
       [71, 64, 56, 52],
       [71, 64, 56, 52],
       [71, 64, 56, 52],
       [64, 59, 56, 52],
       [64, 59, 56, 52],
       [64, 59, 56, 52],
       [64, 59, 56, 52],


In [17]:
# Convert the generated array to nested lists and play it back.
chorale = new_chorale.tolist()

s = stream.Stream()
for generated_row in chorale:
    # Zero entries are left out of the chord.
    audible = [pitch for pitch in generated_row if pitch]
    generated_chord = chord.Chord(audible, quarterLength=1)
    s.append(generated_chord)
s.show('midi')

In [18]:
def generate_random_chorale(length, rest_probability=0.2, pitch_low=36, pitch_high=81, seed=None):
    """Build a chorale-shaped array of uniformly random pitches and rests.

    Parameters
    ----------
    length : int
        Number of rows to produce; every row has four entries.
    rest_probability : float, default 0.2
        Chance that any single entry is replaced by 0 (a rest).
    pitch_low, pitch_high : int, default 36 and 81
        Inclusive bounds of the random pitches.
    seed : optional
        Passed to ``np.random.default_rng``; the same seed gives the same array.

    Returns
    -------
    numpy.ndarray
        Integer array of shape ``(length, 4)`` whose entries are either 0 or a
        pitch in ``[pitch_low, pitch_high]``.
    """
    shape = (length, 4)
    generator = np.random.default_rng(seed)

    # Draw order matters for reproducibility: pitches first, then the rest mask.
    pitches = generator.integers(pitch_low, pitch_high + 1, size=shape)
    is_rest = generator.random(shape) < float(rest_probability)

    # Silence the masked entries on an integer copy of the pitches.
    result = pitches.astype(int)
    result[is_rest] = 0
    return result

In [19]:
# Baseline for comparison: play 56 rows of purely random notes and rests to
# hear how they differ from what the model generated.
chorale = generate_random_chorale(56).tolist()

s = stream.Stream()
for random_row in chorale:
    # Rests (zeros) are removed before the chord is built.
    random_pitches = list(filter(None, random_row))
    s.append(chord.Chord(random_pitches, quarterLength=1))
s.show('midi')