## Retrieve YM2413-MDB (v1.0.2) Dataset

In [None]:
import pathlib

# Location of the tempo-adjusted YM2413-MDB MIDI files, relative to this notebook.
midi_files_directory = pathlib.Path(
    '../music_dataset/YM2413-MDB-v1.0.2/midi/adjust_tempo_remove_delayed_inst'
)

# Materialise every path whose name matches '*.mid*' directly under that folder.
midi_file_paths = [*midi_files_directory.glob('*.mid*')]
# print('Number of MIDI files:', len(midi_file_paths))

## Inspect Files

In [None]:
# import os
# import csv
# import pretty_midi

# # Set maximum tick value for pretty_midi
# pretty_midi.pretty_midi.MAX_TICK = 1e16

# # Initialize list to store inspection results
# inspection_results = []

# for midi_file_path in midi_file_paths:
#     # Convert Path object to string
#     midi_file_path_str = str(midi_file_path)
    
#     # Process MIDI file
#     midi_data = pretty_midi.PrettyMIDI(midi_file_path_str)
#     num_instruments = len(midi_data.instruments)
#     instrument_names = [pretty_midi.program_to_instrument_name(inst.program) for inst in midi_data.instruments]

#     # Store inspection results
#     inspection_results.append({
#         "Filename": os.path.basename(midi_file_path_str),
#         "Number of Instruments": num_instruments,
#         "Instrument Names": instrument_names
#     })

# # Write inspection results to CSV file
# csv_output_file = '../ym2413_jupyter_xt/inspect.csv'
# csv_columns = ["Filename", "Number of Instruments", "Instrument Names"]

# with open(csv_output_file, "w", newline="") as csvfile:
#     writer = csv.DictWriter(csvfile, fieldnames=csv_columns)
#     writer.writeheader()
#     for result in inspection_results:
#         writer.writerow(result)

# print("Inspection results saved to", csv_output_file)

## 1. Convert MIDI to Nintendo Entertainment System (NES) Format

In [None]:
import itertools
import os
import random
import pretty_midi
import csv

# Inclusive MIDI pitch bounds for the NES channels, keyed by the channel names
# used for the output instruments ('p1', 'p2', 'tr').
nes_instrument_name_to_min_pitch = dict(p1=33, p2=33, tr=21)
nes_instrument_name_to_max_pitch = dict.fromkeys(('p1', 'p2', 'tr'), 108)

def is_instrument_monophonic(instrument):
    """Return True when no note of `instrument` overlaps the note that follows it.

    `instrument.notes` must already be ordered by start time; this is asserted.
    A note counts as overlapping when it ends strictly after the next note starts.
    """
    notes = instrument.notes

    # Precondition check: start times must be non-decreasing.
    previous_start = -1
    for current in notes:
        assert current.start >= previous_start
        previous_start = current.start

    # Compare each note with its successor.
    return not any(
        earlier.end > later.start
        for earlier, later in zip(notes, notes[1:])
    )

def generate_nesmdb_midi_examples(
    midi_filepath,
    output_directory,
    min_num_instruments=1,
    min_length_seconds=5.,
    max_length_seconds=600.,
    filter_bad_times=True,
    min_pitch=21,
    max_pitch=108,
    filter_duplicates=True,
    include_drums=True,
    max_examples=16,
    max_duration_seconds=180.,
    emotion_mapping=None):
    """Convert one MIDI file into NES-style 4-channel MIDI examples.

    The melodic instruments of the input are assigned, in several random
    permutations, to the three NES slots 'p1', 'p2' and 'tr'; an optional drum
    track is assigned to the 'no' slot. Each permutation is written to
    ``<output_directory>/<label>_<emotion>/<midi_name>_<i>.mid``.

    The function returns None without writing anything when the file is larger
    than 512 KiB, cannot be parsed, is outside the length bounds, contains bad
    note times (when `filter_bad_times`), or has fewer than
    `min_num_instruments` usable instruments.

    Args:
        midi_filepath: Path of the MIDI file to convert.
        output_directory: Root directory for the generated files.
        min_num_instruments: Minimum number of usable melodic instruments.
        min_length_seconds: Minimum accepted end time of the input.
        max_length_seconds: Maximum accepted end time of the input.
        filter_bad_times: Reject the file if any note has a negative start/end
            or ends before it starts.
        min_pitch: Lower bound of the pitch window an instrument must overlap.
        max_pitch: Upper bound of the pitch window an instrument must overlap.
        filter_duplicates: Drop instruments whose (pitch, start) sequence
            duplicates an earlier instrument (start compared at 0.1 s precision).
        include_drums: Whether a drum track may be assigned to the 'no' slot.
        max_examples: Maximum number of permutations written per file.
        max_duration_seconds: Maximum duration of each written example.
        emotion_mapping: Optional dict from file stem to quadrant label
            ('Q1'..'Q4'). Missing entries (or None) map to 'Unknown'.

    Raises:
        ValueError: If `min_num_instruments` is not positive.
    """
    midi_name = os.path.splitext(os.path.basename(midi_filepath))[0]

    if min_num_instruments <= 0:
        raise ValueError('min_num_instruments must be positive')

    if os.path.getsize(midi_filepath) > (512 * 1024): # 512K
        return

    # Best-effort parse: unreadable files are skipped, but KeyboardInterrupt /
    # SystemExit are no longer swallowed.
    try:
        midi_data = pretty_midi.PrettyMIDI(midi_filepath)
    except Exception:
        return

    midi_length = midi_data.get_end_time()
    if midi_length < min_length_seconds or midi_length > max_length_seconds:
        return

    for instrument in midi_data.instruments:
        for note in instrument.notes:
            if filter_bad_times:
                if note.start < 0 or note.end < 0 or note.end < note.start:
                    return
            # Snap note boundaries to a 1/44100 s grid.
            note.start = round(note.start * 44100.) / 44100.
            note.end = round(note.end * 44100.) / 44100.

    instruments = midi_data.instruments
    drums = [i for i in instruments if i.is_drum]
    instruments = [i for i in instruments if not i.is_drum]

    # Keep instruments whose pitch span overlaps [min_pitch, max_pitch].
    instruments_in_range = []
    for instrument in instruments:
        pitches = [n.pitch for n in instrument.notes]
        if not pitches:
            # An instrument without notes has no pitch span; min()/max() would raise.
            continue
        min_pitch_instrument = min(pitches)
        max_pitch_instrument = max(pitches)
        if max_pitch_instrument >= min_pitch and min_pitch_instrument <= max_pitch:
            instruments_in_range.append(instrument)
    instruments = instruments_in_range
    if len(instruments) < min_num_instruments:
        return

    # is_instrument_monophonic requires notes ordered by start time.
    for instrument in instruments:
        instrument.notes = sorted(instrument.notes, key=lambda x: x.start)
    if include_drums:
        for instrument in drums:
            instrument.notes = sorted(instrument.notes, key=lambda x: x.start)

    instruments = [i for i in instruments if is_instrument_monophonic(i)]
    if len(instruments) < min_num_instruments:
        return

    if filter_duplicates:
        unique_notes = set()
        unique_instruments = []
        for instrument in instruments:
            pitches = ','.join(['{}:{:.1f}'.format(str(note.pitch), note.start) for note in instrument.notes])
            if pitches not in unique_notes:
                unique_instruments.append(instrument)
                unique_notes.add(pitches)
        instruments = unique_instruments
        if len(instruments) < min_num_instruments:
            return

    # Each permutation is (p1, p2, tr) indices into `instruments`; -1 means
    # the slot stays empty.
    num_instruments = len(instruments)
    if num_instruments == 1:
        instrument_permutations = [(0, -1, -1), (-1, 0, -1), (-1, -1, 0)]
    elif num_instruments == 2:
        instrument_permutations = [(-1, 0, 1), (-1, 1, 0), (0, -1, 1), (0, 1, -1), (1, -1, 0), (1, 0, -1)]
    elif num_instruments > 32:
        instrument_permutations = list(itertools.permutations(random.sample(range(num_instruments), 32), 3))
    else:
        instrument_permutations = list(itertools.permutations(range(num_instruments), 3))

    if len(instrument_permutations) > max_examples:
        instrument_permutations = random.sample(instrument_permutations, max_examples)

    # Append a randomly chosen drum index (or -1) as the fourth, 'no', slot.
    num_drums = len(drums) if include_drums else 0
    instrument_permutations_plus_drums = []
    for permutation in instrument_permutations:
        selection = -1 if num_drums == 0 else random.choice(range(num_drums))
        instrument_permutations_plus_drums.append(permutation + (selection,))
    instrument_permutations = instrument_permutations_plus_drums

    # emotion_mapping defaults to None, so fall back to an empty mapping.
    quarter_label = (emotion_mapping or {}).get(midi_name, 'Unknown')
    emotion_mapping_dict = {
        'Q1': 'happy',
        'Q2': 'angry',
        'Q3': 'sad',
        'Q4': 'relaxed'
    }
    emotion_label = emotion_mapping_dict.get(quarter_label, 'Other')
    output_subdir = f"{quarter_label}_{emotion_label}"
    file_output_directory = os.path.join(output_directory, output_subdir)
    os.makedirs(file_output_directory, exist_ok=True)

    # Loop-invariant lookups.
    nes_instrument_names = ['p1', 'p2', 'tr', 'no']
    lead1_program = pretty_midi.instrument_name_to_program('Lead 1 (square)')
    lead2_program = pretty_midi.instrument_name_to_program('Lead 2 (sawtooth)')
    bass_program = pretty_midi.instrument_name_to_program('Synth Bass 1')
    drum_program = pretty_midi.instrument_name_to_program('Breath Noise')

    for i, permutation in enumerate(instrument_permutations):
        lead1_instrument = pretty_midi.Instrument(program=lead1_program, name='p1', is_drum=False)
        lead2_instrument = pretty_midi.Instrument(program=lead2_program, name='p2', is_drum=False)
        bass_instrument = pretty_midi.Instrument(program=bass_program, name='tr', is_drum=False)
        drum_instrument = pretty_midi.Instrument(program=drum_program, name='no', is_drum=True)
        nes_instruments = [lead1_instrument, lead2_instrument, bass_instrument, drum_instrument]

        # Collect the source notes for each slot (None for an empty slot).
        permutation_notes = []
        for midi_instrument_id, nes_instrument_name in zip(permutation, nes_instrument_names):
            if midi_instrument_id < 0:
                permutation_notes.append(None)
            else:
                if nes_instrument_name == 'no':
                    midi_instrument = drums[midi_instrument_id]
                    valid_notes = midi_instrument.notes
                else:
                    midi_instrument = instruments[midi_instrument_id]
                    slot_min_pitch = nes_instrument_name_to_min_pitch[nes_instrument_name]
                    slot_max_pitch = nes_instrument_name_to_max_pitch[nes_instrument_name]
                    valid_notes = [n for n in midi_instrument.notes if slot_min_pitch <= n.pitch <= slot_max_pitch]
                permutation_notes.append(valid_notes)
        assert len(permutation_notes) == 4

        # Time span covered by all selected notes.
        start_time = None
        end_time = None
        for notes in permutation_notes:
            if notes is None or len(notes) == 0:
                continue
            note_start = min([n.start for n in notes])
            note_end = max([n.end for n in notes])
            if start_time is None or note_start < start_time:
                start_time = note_start
            if end_time is None or note_end > end_time:
                end_time = note_end
        if start_time is None or end_time is None:
            continue

        if (end_time - start_time) > max_duration_seconds:
            end_time = start_time + max_duration_seconds

        for notes, instrument_name, instrument in zip(permutation_notes, nes_instrument_names, nes_instruments):
            if notes is None:
                continue

            if instrument_name == 'no':
                # Random table mapping every MIDI drum pitch to a value in 1..16.
                random_noise_mapping = [random.randint(1, 16) for _ in range(128)]

            last_note_end = -1
            for note in notes:
                velocity = note.velocity
                pitch = note.pitch
                note_start = note.start
                note_end = note.end

                # Drop drum hits that start before the previous hit has ended.
                if instrument_name == 'no' and note_start < last_note_end:
                    continue
                last_note_end = note_end

                assert note_start >= start_time
                # Notes extending past the (possibly clipped) end are dropped.
                if note_end > end_time:
                    continue
                assert note_end <= end_time

                # Velocity is rescaled to 1..15; the 'tr' slot always uses 1.
                velocity = 1 if instrument_name == 'tr' else int(round(1. + (14. * velocity / 127.)))
                assert velocity > 0
                if instrument_name == 'no':
                    pitch = random_noise_mapping[pitch]
                # Shift so the example starts at time zero.
                note_start = note_start - start_time
                note_end = note_end - start_time
                instrument.notes.append(
                    pretty_midi.Note(
                        velocity=velocity,
                        pitch=pitch,
                        start=note_start,
                        end=note_end
                    )
                )

        midi_output = pretty_midi.PrettyMIDI()
        for inst in nes_instruments:
            if len(inst.notes) > 0:
                midi_output.instruments.append(inst)
        # Extra 'end' instrument with a single note placed at the example end.
        end_instrument = pretty_midi.Instrument(program=0, name='end', is_drum=False)
        end_instrument.notes.append(
            pretty_midi.Note(
                velocity=15,
                pitch=108,
                start=(end_time - start_time),
                end=(end_time - start_time) + .1
            )
        )
        midi_output.instruments.append(end_instrument)

        output_midi_filepath = os.path.join(file_output_directory, f"{midi_name}_{i}.mid")
        midi_output.write(output_midi_filepath)

def load_emotion_mapping(csv_filepath):
    """Read the emotion annotation CSV into a ``{file stem: label}`` dict.

    The first row is treated as a header and skipped. For every other row the
    key is the first column reduced to its base name without extension, and
    the value is the fourth column. Blank rows and rows with fewer than four
    columns are ignored; an empty file yields an empty dict.

    Args:
        csv_filepath: Path to the annotation CSV file.

    Returns:
        dict mapping file stem to the label found in the fourth column.
    """
    emotion_mapping = {}
    # newline='' lets the csv module handle line endings itself.
    with open(csv_filepath, mode='r', newline='') as infile:
        reader = csv.reader(infile)
        next(reader, None)  # skip header; tolerate a completely empty file
        for row in reader:
            if len(row) < 4:
                continue
            filename = os.path.splitext(os.path.basename(row[0]))[0]
            emotion_mapping[filename] = row[3]
    return emotion_mapping

if __name__ == '__main__':
    # Set maximum tick value for pretty_midi before any file is parsed.
    pretty_midi.pretty_midi.MAX_TICK = 1e16

    midi_data_directory = '../music_dataset/YM2413-MDB-v1.0.2/midi/adjust_tempo_remove_delayed_inst'
    output_directory = './1_output'
    emotion_mapping = load_emotion_mapping('../music_dataset/YM2413-MDB-v1.0.2/emotion_annotation/verified_annotation.csv')

    # Convert every '.mid' file found directly inside the dataset folder.
    for midi_filename in os.listdir(midi_data_directory):
        if not midi_filename.endswith('.mid'):
            continue
        source_path = os.path.join(midi_data_directory, midi_filename)
        generate_nesmdb_midi_examples(
            source_path,
            output_directory,
            emotion_mapping=emotion_mapping)

## Get Total Number of Files After MIDI-NES Conversion

In [None]:
# midi_files_directory = pathlib.Path('./1_output')
# midi_file_paths = list(midi_files_directory.rglob('*.mid*'))
# print('Number of files (after permutations):', len(midi_file_paths))

## Inspect MIDI Files After MIDI-NES Conversion

In [None]:
# import os
# import csv
# import pretty_midi

# pretty_midi.pretty_midi.MAX_TICK = 1e16

# # Initialize list to store inspection results
# inspection_results = []

# for midi_file_path in midi_file_paths:
#     # Convert Path object to string
#     midi_file_path_str = str(midi_file_path)
    
#     # Process MIDI file
#     midi_data = pretty_midi.PrettyMIDI(midi_file_path_str)
#     num_instruments = len(midi_data.instruments)
#     instrument_names = [pretty_midi.program_to_instrument_name(inst.program) for inst in midi_data.instruments]

#     # Store inspection results
#     inspection_results.append({
#         "Filename": os.path.basename(midi_file_path_str),
#         "Number of Instruments": num_instruments,
#         "Instrument Names": instrument_names
#     })

# # Write inspection results to CSV file
# csv_output_file = './converted_inspect.csv'
# csv_columns = ["Filename", "Number of Instruments", "Instrument Names"]

# with open(csv_output_file, "w", newline="") as csvfile:
#     writer = csv.DictWriter(csvfile, fieldnames=csv_columns)
#     writer.writeheader()
#     for result in inspection_results:
#         writer.writerow(result)

# print("Inspection results saved to", csv_output_file)

## 2. Extract Musical Information From Formatted MIDI Files

### Extract Musical Features

In [None]:
import pretty_midi
import numpy as np
import os
import json
import re

# Module-level set of instrument names.
# NOTE(review): process_midi_dataset below creates its own local set and passes
# that to extract_midi_features, so this global looks unused in this cell — confirm.
instrument_vocab = set()

def round_features(features, precision=4):
    """Return new feature dicts with 'duration' and 'tempo' rounded.

    'pitch' and 'velocity' are copied unchanged. A missing 'tempo' is treated
    as 0. The input dicts are not modified.
    """
    return [
        {
            'pitch': entry['pitch'],
            'velocity': entry['velocity'],
            'duration': round(entry['duration'], precision),
            'tempo': round(entry.get('tempo', 0), precision),
        }
        for entry in features
    ]

def extract_midi_features(midi_file, instrument_vocab, precision=4):
    """Extract per-instrument note/silence events from a MIDI file.

    For every non-drum instrument the notes are turned into a list of event
    dicts with keys 'pitch', 'velocity', 'duration' and 'tempo'. A gap between
    the end of the previous note and the start of the next one is emitted as a
    silence event (pitch 0, velocity 0, tempo 0). Events are rounded with
    `round_features`.

    Args:
        midi_file: Path of the MIDI file to read.
        instrument_vocab: Set that receives the name of every processed
            instrument (mutated in place).
        precision: Number of decimals kept for 'duration' and 'tempo'.

    Returns:
        Tuple ``(instrument_features, instrument_vocab)`` where
        `instrument_features` maps instrument name to its event list. Two
        instruments with the same program name share one key; the later one
        overwrites the earlier.
    """
    # print(f'Loading MIDI file: {midi_file}')
    midi_data = pretty_midi.PrettyMIDI(midi_file)

    # The tempo map belongs to the file, not to an instrument: fetch it once.
    tempo_times, tempo_values = midi_data.get_tempo_changes()

    instrument_features = {}
    for instrument in midi_data.instruments:
        if instrument.is_drum:
            continue

        instrument_name = pretty_midi.program_to_instrument_name(instrument.program)
        # print(f'Processing instrument: {instrument_name}')
        instrument_vocab.add(instrument_name)  # Update the set with new instrument

        # One tempo per note, in note order.
        # NOTE(review): np.interp interpolates linearly between tempo-change
        # points; with a single tempo entry it is simply constant.
        note_tempos = np.interp(
            [note.start for note in instrument.notes], tempo_times, tempo_values)

        notes_data = []
        last_end_time = 0
        # Pair each note with its own tempo directly. The previous approach
        # re-matched tempos to events by filtering on pitch != 0, which shifted
        # every later tempo whenever a real note had pitch 0.
        for note, tempo in zip(instrument.notes, note_tempos):
            # Handle silence between notes
            if note.start > last_end_time:
                notes_data.append({
                    'pitch': 0,
                    'velocity': 0,
                    'duration': note.start - last_end_time,
                    'tempo': 0
                })

            notes_data.append({
                'pitch': note.pitch,
                'velocity': note.velocity,
                'duration': note.end - note.start,
                'tempo': tempo
            })
            last_end_time = note.end

        instrument_features[instrument_name] = round_features(notes_data, precision)
    return instrument_features, instrument_vocab

def process_midi_dataset(dataset_path, output_path, process_path, precision=4):
    """Extract features for every '.mid' file under the mood folders of a dataset.

    Each sub-directory of `dataset_path` is treated as a mood folder; a leading
    'Q<digits>_' prefix is stripped from its name to get the mood label. For
    each '.mid' file a JSON file with 'track_id', 'mood' and 'instruments' is
    written to ``<output_path>/<mood>/<track_id>.json``. Files without any
    event are reported and skipped. Finally the sorted instrument names are
    written, mapped to their index, to ``<process_path>/instrument_vocab.json``.
    """
    instrument_vocab = set()

    for mood_folder in os.listdir(dataset_path):
        mood_path = os.path.join(dataset_path, mood_folder)
        if not os.path.isdir(mood_path):
            continue

        mood_label_clean = re.sub(r'^Q\d+_', '', os.path.basename(mood_path))
        print(f'Processing mood: {mood_label_clean}')

        for midi_file in os.listdir(mood_path):
            if not midi_file.endswith('.mid'):
                continue

            instrument_features, instrument_vocab = extract_midi_features(
                os.path.join(mood_path, midi_file), instrument_vocab, precision)

            # Nothing to save when no instrument produced at least one event.
            has_events = any(len(events) != 0 for events in instrument_features.values())
            if not has_events:
                print(f"No valid features found in {midi_file}, skipping.")
                continue

            track_id = os.path.splitext(midi_file)[0]
            output_subpath = os.path.join(output_path, mood_label_clean)
            os.makedirs(output_subpath, exist_ok=True)

            payload = {
                'track_id': track_id,
                'mood': mood_label_clean,
                'instruments': instrument_features
            }
            with open(os.path.join(output_subpath, f'{track_id}.json'), 'w') as json_file:
                json.dump(payload, json_file, indent=4)

    os.makedirs(process_path, exist_ok=True)

    # Persist the instrument vocabulary as name -> index (alphabetical order).
    instrument_dict = {}
    for idx, instrument in enumerate(sorted(instrument_vocab)):
        instrument_dict[instrument] = idx
    vocab_file_path = os.path.join(process_path, 'instrument_vocab.json')
    with open(vocab_file_path, 'w') as vocab_file:
        json.dump(instrument_dict, vocab_file, indent=4)

if __name__ == '__main__':
    # Input: converted NES MIDI files; outputs: per-track features and the
    # instrument vocabulary.
    dataset_path = './1_output'
    output_path = './2_output_features'
    process_path = './3_processed_features'

    process_midi_dataset(
        dataset_path=dataset_path,
        output_path=output_path,
        process_path=process_path,
    )

### Transform Extracted Musical Features

In [None]:
import json
import os
from collections import defaultdict

def load_instrument_vocab(vocab_path):
    """Read the JSON file at `vocab_path` and return its decoded content."""
    with open(vocab_path, 'r') as vocab_handle:
        vocab = json.load(vocab_handle)
    return vocab

def load_instrument_specific_vocab(dataset_path):
    """Build one event vocabulary per instrument from the feature JSON files.

    Every sub-directory of `dataset_path` is scanned for '.json' files. Each
    distinct event (its items sorted into a tuple) of an instrument gets the
    next free integer ID of that instrument, starting at 1; ID 0 is reserved
    for "<PAD>".

    Returns:
        defaultdict mapping instrument name to ``{event_key: id}``.
    """
    instrument_event_vocab = defaultdict(lambda: {"<PAD>": 0})  # per-instrument vocab
    instrument_event_id = defaultdict(lambda: 1)  # next free ID per instrument

    for mood in os.listdir(dataset_path):
        mood_path = os.path.join(dataset_path, mood)
        if not os.path.isdir(mood_path):
            continue
        print(f'Processing mood: {mood}')

        json_names = [name for name in os.listdir(mood_path) if name.endswith('.json')]
        for json_name in json_names:
            with open(os.path.join(mood_path, json_name), 'r') as handle:
                data = json.load(handle)

            for instrument, events in data.get('instruments', {}).items():
                for event in events:
                    # Sorted item tuple: hashable and independent of key order.
                    event_key = tuple(sorted(event.items()))
                    if event_key in instrument_event_vocab[instrument]:
                        continue
                    instrument_event_vocab[instrument][event_key] = instrument_event_id[instrument]
                    instrument_event_id[instrument] += 1
    return instrument_event_vocab

def save_vocab(vocab, output_folder):
    """Write each instrument's vocabulary to ``<output_folder>/<instrument>_vocab.json``.

    Keys are converted with ``str()`` because JSON object keys must be strings.
    The folder is created when it does not exist.
    """
    os.makedirs(output_folder, exist_ok=True)

    for instrument, events in vocab.items():
        output_path = os.path.join(output_folder, f"{instrument}_vocab.json")

        serializable = {}
        for key, value in events.items():
            serializable[str(key)] = value

        with open(output_path, 'w') as file:
            json.dump(serializable, file, indent=4)
        print(f'Vocabulary for {instrument} saved to {output_path}')

def load_all_vocabs(vocab_folder):
    """Load all vocabularies from the specified folder into a dictionary.

    Every file named ``<instrument>_vocab.json`` is read. Its string keys are
    parsed back into tuples; the "<PAD>" entry is skipped. Keys that cannot be
    parsed as a Python literal sequence are reported and skipped.

    Args:
        vocab_folder: Directory containing the ``*_vocab.json`` files.

    Returns:
        dict mapping instrument name to ``{event_tuple: id}``.
    """
    # Local import: the enclosing cell does not import ast.
    import ast

    suffix = '_vocab.json'
    processed_vocabs = {}
    for vocab_file in os.listdir(vocab_folder):
        if not vocab_file.endswith(suffix):
            continue
        # Strip only the trailing suffix (str.replace would remove every occurrence).
        instrument = vocab_file[:-len(suffix)]
        with open(os.path.join(vocab_folder, vocab_file), 'r') as file:
            vocab = json.load(file)

        # Convert keys from string to tuple, skip <PAD> token
        vocab_dict = {}
        for key, value in vocab.items():
            if key == "<PAD>":
                continue
            try:
                # literal_eval only accepts literals, so file content can never
                # run code the way eval() could.
                tuple_key = tuple(ast.literal_eval(key))
            except (SyntaxError, ValueError, TypeError):
                # Not a literal sequence: report and skip this entry.
                print(f"Error evaluating key {key}: Skipping")
                continue
            vocab_dict[tuple_key] = value
        processed_vocabs[instrument] = vocab_dict
    return processed_vocabs

def transform_dataset(dataset_path, vocabs, instrument_vocab, output_folder):
    """Replace the events of every feature JSON file by their vocabulary IDs.

    For each '.json' file in each sub-directory of `dataset_path` a file with
    the same name is written to ``<output_folder>/<mood>/`` containing:

    * 'mood': copied from the input;
    * 'instrument_vector': a 0/1 list of length ``len(instrument_vocab)`` with
      a 1 at the index of every instrument present in the track;
    * 'instruments': instrument name -> list of event IDs (-1 for an event
      missing from that instrument's vocabulary). Instruments without a
      vocabulary in `vocabs` are left out.

    Args:
        dataset_path: Root directory of the per-mood feature JSON files.
        vocabs: dict instrument name -> ``{event_tuple: id}``.
        instrument_vocab: dict instrument name -> index into the vector.
        output_folder: Root directory for the transformed files.
    """
    num_instruments = len(instrument_vocab)
    os.makedirs(output_folder, exist_ok=True)
    for mood in os.listdir(dataset_path):
        mood_path = os.path.join(dataset_path, mood)
        if not os.path.isdir(mood_path):
            continue
        output_mood_path = os.path.join(output_folder, mood)
        os.makedirs(output_mood_path, exist_ok=True)

        for json_file in os.listdir(mood_path):
            if not json_file.endswith('.json'):
                continue
            # Read first and close the input before writing the output.
            with open(os.path.join(mood_path, json_file), 'r') as file:
                data = json.load(file)

            transformed_instruments = {}
            instrument_vector = [0] * num_instruments
            for instrument, events in data['instruments'].items():
                # Mark the instrument as present; previously the vector was
                # allocated but never filled, so it was always all zeros.
                instrument_index = instrument_vocab.get(instrument)
                if isinstance(instrument_index, int) and 0 <= instrument_index < num_instruments:
                    instrument_vector[instrument_index] = 1

                if instrument in vocabs:  # Check if there's a vocab for this instrument
                    vocab = vocabs[instrument]
                    # Get event ID or -1 if not found
                    transformed_instruments[instrument] = [
                        vocab.get(tuple(sorted(event.items())), -1) for event in events
                    ]

            transformed_data = {
                'mood': data['mood'],
                'instrument_vector': instrument_vector,
                'instruments': transformed_instruments
            }
            output_file_path = os.path.join(output_mood_path, json_file)
            with open(output_file_path, 'w') as out_file:
                json.dump(transformed_data, out_file, indent=4)

if __name__ == '__main__':
    # Inputs produced by the feature-extraction step.
    dataset_path = './2_output_features'
    instrument_vocab_path = './3_processed_features/instrument_vocab.json'
    # Outputs of this step.
    vocab_output_folder = './3_processed_features/instrument_vocabs'
    processed_folder = './3_processed_features/data'

    # Build the per-instrument event vocabularies and persist them.
    instrument_vocab = load_instrument_specific_vocab(dataset_path)
    instru_vocab = load_instrument_vocab(instrument_vocab_path)
    save_vocab(vocab=instrument_vocab, output_folder=vocab_output_folder)
    vocabs = load_all_vocabs(vocab_folder=vocab_output_folder)

    # Rewrite the dataset as event-ID sequences using the reloaded vocabularies.
    transform_dataset(
        dataset_path=dataset_path,
        vocabs=vocabs,
        instrument_vocab=instru_vocab,
        output_folder=processed_folder,
    )

## Model Training

### 1. Load and Prepare Data

In [None]:
import json
import os
import numpy as np
from tensorflow.keras.preprocessing.sequence import pad_sequences

data_path = './3_processed_features/data'

def load_data(data_path):
    """Load event-ID sequences and mood labels from the transformed dataset.

    Every non-empty instrument track of every '.json' file becomes one sample,
    labelled with the integer of its mood folder. Entries of `data_path` that
    are not directories, folders whose name is not one of the four known moods
    (for example the folder produced for tracks without an annotation), and
    non-JSON files are skipped instead of raising.

    Args:
        data_path: Root directory containing one sub-directory per mood.

    Returns:
        Tuple ``(sequences, labels)``: the output of ``pad_sequences`` with
        post-padding, and a numpy array of integer labels.
    """
    sequences = []
    labels = []
    mood_labels = {'angry': 0, 'happy': 1, 'relaxed': 2, 'sad': 3}

    for mood in os.listdir(data_path):
        mood_path = os.path.join(data_path, mood)
        if not os.path.isdir(mood_path):
            continue
        if mood not in mood_labels:
            # No class index exists for this folder; indexing would raise KeyError.
            print(f"Skipping folder without a known mood label: {mood}")
            continue
        label = mood_labels[mood]

        for file in os.listdir(mood_path):
            if not file.endswith('.json'):
                continue
            file_path = os.path.join(mood_path, file)
            with open(file_path, 'r') as json_file:
                data = json.load(json_file)
            for instrument, events in data['instruments'].items():
                if events:  # ensure there are events
                    sequences.append(events)
                    labels.append(label)

    # Convert lists to numpy arrays
    sequences = pad_sequences(sequences, padding='post')
    labels = np.array(labels)
    return sequences, labels

sequences, labels = load_data(data_path)
print("Loaded sequences:", sequences.shape)
print("Loaded labels:", labels.shape)

### 2. Build Model

In [None]:
import os
import json

vocab_directory = './3_processed_features/instrument_vocabs'
max_vocab_id = 0

# Find the highest event ID over all per-instrument vocabulary files.
# Only '*_vocab.json' entries are read, matching what load_all_vocabs accepts,
# so stray files or sub-directories in the folder no longer break the scan.
for filename in os.listdir(vocab_directory):
    if not filename.endswith('_vocab.json'):
        continue
    filepath = os.path.join(vocab_directory, filename)
    with open(filepath, 'r') as file:
        vocab = json.load(file)
    # default=0 keeps an empty vocabulary file from raising.
    max_id = max(map(int, vocab.values()), default=0)
    max_vocab_id = max(max_vocab_id, max_id)

# IDs start at 0 (the "<PAD>" entry), so the size is the highest ID plus one.
max_vocab_size = max_vocab_id + 1

print("The maximum vocabulary size is:", max_vocab_size)

In [None]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Embedding, Dense, Dropout

# Hyperparameters
vocab_size = max_vocab_size
embedding_dim = 64
lstm_units = 128
num_classes = 4  # Number of mood categories
dropout_rate = 0.4

# Define the model
def build_model(vocab_size, embedding_dim, lstm_units, num_classes, dropout_rate):
    """Build and compile the stacked-LSTM mood classifier.

    Layers, in order: a masking Embedding (ID 0 is treated as padding), an LSTM
    returning full sequences, Dropout, a second LSTM, Dropout, and a softmax
    Dense layer with `num_classes` outputs. The model is compiled with the
    Adam optimizer, sparse categorical cross-entropy and accuracy.
    """
    layer_stack = [
        Embedding(input_dim=vocab_size, output_dim=embedding_dim, mask_zero=True),
        LSTM(lstm_units, return_sequences=True),
        Dropout(dropout_rate),
        LSTM(lstm_units),
        Dropout(dropout_rate),
        Dense(num_classes, activation='softmax'),
    ]
    classifier = Sequential(layer_stack)

    # Integer class labels -> sparse categorical cross-entropy.
    classifier.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return classifier

model = build_model(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    lstm_units=lstm_units,
    num_classes=num_classes,
    dropout_rate=dropout_rate,
)
model.summary()

### 3. Train Model

In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

import numpy as np  # used below to shuffle the samples before the split

model_dir = './trained_models'
os.makedirs(model_dir, exist_ok=True)

# Callbacks for monitoring and improving training.
# A checkpoint is written after every epoch (save_best_only=False); the file
# name embeds the epoch number and the validation loss.
checkpoint = ModelCheckpoint(
    os.path.join(model_dir, 'model_epoch{epoch:02d}_loss{val_loss:.2f}.keras'),
    save_best_only=False,
    monitor='val_loss',
    mode='min',
    verbose=1
)

# early_stopping = EarlyStopping(monitor='val_loss', patience=5, verbose=1)

# Keras takes the validation split from the LAST samples, before any shuffling.
# load_data appends samples one mood folder after another, so without this
# permutation the validation set would be dominated by the last-listed mood(s).
# The fixed seed keeps the split reproducible.
shuffle_order = np.random.default_rng(42).permutation(len(sequences))
shuffled_sequences = sequences[shuffle_order]
shuffled_labels = labels[shuffle_order]

# Train the model
history = model.fit(
    shuffled_sequences, shuffled_labels,
    epochs=3,
    batch_size=128,
    validation_split=0.2,  # Use 20% of the data for validation
    callbacks=[checkpoint]
)