# Custom Training Code <br>
Built by Alex Fisher and Kevin Parra-Olmedo

## Import Dependencies

In [None]:
import tensorflow as tf
# Report the GPUs visible to TensorFlow (this only lists devices; it does not
# enable or configure any of them).
physical_devices = tf.config.list_physical_devices('GPU')
print(physical_devices)
# Reuse the list queried above instead of asking again through the
# `tf.config.experimental` alias.
print("Num GPUs Available: ", len(physical_devices))
# Keep per-op device placement logging switched off.
tf.debugging.set_log_device_placement(False)
# Turn soft device placement on.
tf.config.set_soft_device_placement(True)

In [None]:
import numpy as np
import pretty_midi
import librosa


from basic_pitch import inference
from basic_pitch import models

from basic_pitch.constants import (
    ANNOT_N_FRAMES,
    ANNOTATIONS_FPS,
    ANNOTATIONS_N_SEMITONES,
    AUDIO_N_SAMPLES,
    N_FREQ_BINS_CONTOURS,
    AUDIO_SAMPLE_RATE,
    FFT_HOP
)

# Notebook-wide configuration.
BATCH_SIZE = 16  # copied into `batch_size` in the on-the-fly data cell
SPLIT_INTERVAL = 2  # length in seconds of each audio/label segment (passed as `duration=` to librosa.load)
DATASET_PERCENTAGE = 1  # NOTE(review): not referenced in the visible code; the loader hard-codes 0.01 — confirm intent

# Short alias for the Keras layers namespace.
tfkl = tf.keras.layers

## Load in sample dataset files<br>
We use a small sample of the MAESTRO dataset's 100 GB of MIDI/WAV files.

In [None]:
import json

# Load the MAESTRO metadata index. The encoding is given explicitly so the
# read does not depend on the platform's default text encoding.
with open('./datasets/maestro-v3.0.0/maestro-v3.0.0.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

print("Number of samples:", len(data['midi_filename']))

# Each metadata column is looked up by the stringified row index ("0", "1", ...),
# so build every (audio path, MIDI path) pair from the same key.
audio_midi_pairs = [
    (
        './datasets/maestro-v3.0.0/' + data['audio_filename'][f"{i}"],
        './datasets/maestro-v3.0.0/' + data['midi_filename'][f"{i}"],
    )
    for i in range(len(data['midi_filename']))
]

# Keep only the first 1% of the pairs as the working sample.
audio_midi_pairs = audio_midi_pairs[:int(len(audio_midi_pairs) * 0.01)]
print("Number of samples used: " + str(len(audio_midi_pairs)))
print(audio_midi_pairs)

# Preprocess audio and MIDI pair files
The audio must match the model's input format (windowed audio, produced with a modified version of Basic Pitch's `inference.get_audio_input` function).<br>
The MIDI must match the model's output format (a binary matrix).

In [None]:
# FUNCTION TO PREPROCESS MIDI TO BINARY CLASSIFICATION NOTE ONSET MATRIX
def midi_to_piano_onset_matrix(midi_path, frames_per_second=ANNOTATIONS_FPS):
    """
    Convert a MIDI file to a binary note-onset matrix sampled at a fixed frame rate.

    Parameters:
    - midi_path (str): Path to the MIDI file.
    - frames_per_second (int or float): Number of time frames per second of MIDI.

    Returns:
    - numpy.ndarray: Array of shape (total_frames, 88). Rows are time frames and
      columns are the 88 piano keys (MIDI pitches 21-108 map to columns 0-87).
      An entry is 1 where a note of that pitch starts in that frame, else 0.
    """
    # MIDI pitch range of a standard 88-key piano.
    lowest_pitch = 21
    highest_pitch = 108
    num_piano_keys = highest_pitch - lowest_pitch + 1

    midi_data = pretty_midi.PrettyMIDI(midi_path)

    # The number of frames is derived from the end time of the MIDI data.
    total_frames = int(midi_data.get_end_time() * frames_per_second)
    binary_matrix = np.zeros((total_frames, num_piano_keys))

    for instrument in midi_data.instruments:
        for note in instrument.notes:
            # Ignore pitches that fall outside the piano range.
            if not lowest_pitch <= note.pitch <= highest_pitch:
                continue

            onset_frame = int(note.start * frames_per_second)

            # An onset that lands on or past the last frame boundary has no row.
            if onset_frame >= total_frames:
                continue

            binary_matrix[onset_frame, note.pitch - lowest_pitch] = 1

    return binary_matrix


The following three data processing methods produce functional processed data for model training. However, only method 3 is considered scalable for training on large datasets.

In [None]:
# METHOD 1: PREPROCESS ENTIRE DATASET AND STORE IN MEMORY
# THIS USES THE MOST MEMORY, UNSCALABLE SOLUTION

os_x = []
os_y = []

# The windowing parameters do not depend on the file or the offset, so they are
# computed (and validated) once instead of once per segment.
n_overlapping_frames = 30
overlap_len = n_overlapping_frames * FFT_HOP
hop_size = AUDIO_N_SAMPLES - overlap_len
assert overlap_len % 2 == 0, "overlap_length must be even, got {}".format(overlap_len)

# Number of label frames expected in one SPLIT_INTERVAL-second segment.
frames_per_split = ANNOTATIONS_FPS * SPLIT_INTERVAL

# Only the first 20% of the selected pairs are processed here.
for audio_filename, midi_filename in audio_midi_pairs[:int(len(audio_midi_pairs)*0.2)]:
    # preprocess midi
    onsets = midi_to_piano_onset_matrix(midi_filename, frames_per_second=ANNOTATIONS_FPS)

    # Read the duration once per file; it used to be re-read on every loop test.
    # NOTE(review): newer librosa releases prefer `path=` over `filename=` — confirm
    # against the installed version.
    audio_duration = librosa.get_duration(filename=audio_filename)

    offset = 0
    while offset < audio_duration - SPLIT_INTERVAL:
        # preprocess audio: load one segment starting at `offset` (adapted from
        # Basic Pitch's get_audio_input so that it can read from an offset)
        audio_original, _ = librosa.load(audio_filename, sr=AUDIO_SAMPLE_RATE, offset=offset, duration=SPLIT_INTERVAL, mono=True)

        # Prepend half an overlap of silence before windowing.
        audio_original = np.concatenate([np.zeros((int(overlap_len / 2),), dtype=np.float32), audio_original])
        audio_windowed, _ = inference.window_audio_file(audio_original, hop_size)

        os_x.append(audio_windowed)

        # Cut the matching label frames and zero-pad a short final slice.
        split_onsets = onsets[int(offset*ANNOTATIONS_FPS):int((offset+SPLIT_INTERVAL)*ANNOTATIONS_FPS), :]
        if split_onsets.shape[0] < frames_per_split:
            padding = frames_per_split - split_onsets.shape[0]
            split_onsets = np.pad(split_onsets, [(0, padding), (0, 0)], 'constant')
        os_y.append(split_onsets)

        offset += SPLIT_INTERVAL

# First 80% of the segments are used for training, the remainder for validation.
tensor_dataset = tf.data.Dataset.from_tensor_slices((os_x, os_y))
n_train_segments = int(len(tensor_dataset)*0.8)
train_dataset = tensor_dataset.take(n_train_segments)
val_dataset = tensor_dataset.skip(n_train_segments)

print("train_dataset: ", train_dataset)
print("val_dataset: ", val_dataset)

take_count = sum(1 for _ in train_dataset)
print(f"Size of take_dataset: {take_count}")

skip_count = sum(1 for _ in val_dataset)
print(f"Size of skip_dataset: {skip_count}")

ds_count = sum(1 for _ in tensor_dataset)
print(f"Size of batched_dataset: {ds_count}")

print("\n\n\n\n")
for audio, onset in train_dataset.take(1):  # Adjust the number taken as needed
    print("Audio shape:", audio.shape)
    print("Onset shape:", onset.shape)
    # Optionally, visually inspect the actual data. The dataset is not batched
    # here, so index 0 selects along the first axis of a single element.
    print("Audio data sample:", audio[0])
    print("Onset data sample:", onset[0])

In [None]:
# METHOD 2: PROCESS ENTIRE DATA BEFOREHAND AND SAVE IN TFRECORDS
# THIS IS ALSO UNSCALABLE, SAVES ENTIRE PROCESSED DATASET IN PERSISTENT MEMORY AS A BINARY FILE

def serialize_example(audio_windowed, split_onsets):
    """Serialize one (windowed audio, onset matrix) pair as a ``tf.train.Example``.

    Both inputs are flattened to 1-D and stored as float lists under the keys
    'audio_windowed' and 'split_onsets'. The original shapes are not stored, so
    the reader has to know them in order to reshape the values back.

    Returns:
        The serialized Example produced by ``SerializeToString()``.
    """
    def _flat_float_feature(values):
        # Flatten with TensorFlow's reshape, then wrap the values as a float list.
        flattened = tf.reshape(values, [-1])
        return tf.train.Feature(float_list=tf.train.FloatList(value=flattened.numpy()))

    features = tf.train.Features(feature={
        'audio_windowed': _flat_float_feature(audio_windowed),
        'split_onsets': _flat_float_feature(split_onsets),
    })
    return tf.train.Example(features=features).SerializeToString()


# The windowing parameters do not depend on the file or the offset, so they are
# computed (and validated) once instead of once per segment.
n_overlapping_frames = 30
overlap_len = n_overlapping_frames * FFT_HOP
hop_size = AUDIO_N_SAMPLES - overlap_len
assert overlap_len % 2 == 0, "overlap_length must be even, got {}".format(overlap_len)

# Number of label frames expected in one SPLIT_INTERVAL-second segment.
frames_per_split = ANNOTATIONS_FPS * SPLIT_INTERVAL

with tf.io.TFRecordWriter('train.tfrecords') as writer:
    # The `* 1` factor keeps every selected pair; lower it to write a subset.
    for audio_filename, midi_filename in audio_midi_pairs[:int(len(audio_midi_pairs)*1)]:
        # preprocess midi
        onsets = midi_to_piano_onset_matrix(midi_filename, frames_per_second=ANNOTATIONS_FPS)

        # Read the duration once per file; it used to be re-read on every loop test.
        # NOTE(review): newer librosa releases prefer `path=` over `filename=` —
        # confirm against the installed version.
        audio_duration = librosa.get_duration(filename=audio_filename)

        offset = 0
        while offset < audio_duration - SPLIT_INTERVAL:
            # preprocess audio: load one segment starting at `offset` (adapted
            # from Basic Pitch's get_audio_input so that it can read from an offset)
            audio_original, _ = librosa.load(audio_filename, sr=AUDIO_SAMPLE_RATE, offset=offset, duration=SPLIT_INTERVAL, mono=True)

            # Prepend half an overlap of silence before windowing.
            audio_original = np.concatenate([np.zeros((int(overlap_len / 2),), dtype=np.float32), audio_original])
            audio_windowed, _ = inference.window_audio_file(audio_original, hop_size)

            # Cut the matching label frames and zero-pad a short final slice.
            split_onsets = onsets[int(offset*ANNOTATIONS_FPS):int((offset+SPLIT_INTERVAL)*ANNOTATIONS_FPS), :]
            if split_onsets.shape[0] < frames_per_split:
                padding = frames_per_split - split_onsets.shape[0]
                split_onsets = np.pad(split_onsets, [(0, padding), (0, 0)], 'constant')

            # write sample to file
            sample = serialize_example(audio_windowed, split_onsets)
            writer.write(sample)

            offset += SPLIT_INTERVAL


def parse_tfrecord(example_proto, audio_shape, onsets_shape):
    """Decode one serialized example back into (audio_windowed, split_onsets).

    Args:
        example_proto: a serialized example read from the TFRecord file.
        audio_shape: shape to restore the flattened 'audio_windowed' values to.
        onsets_shape: shape to restore the flattened 'split_onsets' values to.

    Returns:
        Tuple ``(audio_windowed, split_onsets)`` of float32 tensors reshaped to
        ``audio_shape`` and ``onsets_shape`` respectively.
    """
    target_shapes = {
        'audio_windowed': audio_shape,
        'split_onsets': onsets_shape,
    }
    # Each feature was stored flat, so its fixed length is the product of its shape.
    feature_spec = {
        key: tf.io.FixedLenFeature([np.prod(shape)], tf.float32)
        for key, shape in target_shapes.items()
    }
    parsed = tf.io.parse_single_example(example_proto, feature_spec)

    # Restore the original shapes.
    return (
        tf.reshape(parsed['audio_windowed'], audio_shape),
        tf.reshape(parsed['split_onsets'], onsets_shape),
    )

# Replace these with the actual shapes of your audio_windowed and split_onsets
audio_shape = (2, 43844, 1)
onsets_shape = (172, 88)

raw_dataset = tf.data.TFRecordDataset('train.tfrecords')
parsed_dataset = raw_dataset.map(lambda x: parse_tfrecord(x, audio_shape, onsets_shape))

# Count the examples by streaming through them once. Building a list of every
# decoded example just to take its length would hold the whole dataset in memory.
dataset_size = sum(1 for _ in parsed_dataset)  # Number of items in the dataset

# First 80% of the examples are used for training, the remainder for validation.
train_size = int(dataset_size * 0.8)
val_size = dataset_size - train_size

train_dataset = parsed_dataset.take(train_size)
val_dataset = parsed_dataset.skip(train_size)

buffer_size = dataset_size  # Set buffer size to the dataset size for complete shuffling

# Shuffling is currently disabled:
#train_dataset = train_dataset.shuffle(buffer_size, reshuffle_each_iteration=True)
#val_dataset = val_dataset.shuffle(buffer_size, reshuffle_each_iteration=True)


for audio, onset in train_dataset.take(1):  # Adjust the number taken as needed
    print("Audio shape:", audio.shape)
    print("Onset shape:", onset.shape)
    # Optionally, visually inspect the actual data. The dataset is not batched
    # here, so index 0 selects along the first axis of a single element.
    print("Audio data sample:", audio[0])
    print("Onset data sample:", onset[0])

In [None]:
# METHOD 3: PROCESS DATA ON-THE-FLY USING TENSORFLOW DATA GENERATOR
# THIS METHOD IS SCALABLE, IT WILL PROCESS THE DATASET SEQUENTIALLY AS DATA IS PULLED, NO MATTER HOW LARGE THE DATASET IS

def preprocess_data(audio_filename, midi_filename, AUDIO_SAMPLE_RATE, SPLIT_INTERVAL, ANNOTATIONS_FPS, AUDIO_N_SAMPLES, FFT_HOP):
    """Yield (windowed audio, onset matrix) pairs for one audio/MIDI file pair.

    The recording is cut into consecutive SPLIT_INTERVAL-second segments. Each
    segment is loaded at AUDIO_SAMPLE_RATE, prefixed with half an overlap of
    zeros and windowed with ``inference.window_audio_file``. The matching rows
    of the MIDI onset matrix are sliced out and zero-padded to
    ``ANNOTATIONS_FPS * SPLIT_INTERVAL`` frames when short.

    Args:
        audio_filename: path of the audio file.
        midi_filename: path of the matching MIDI file.
        AUDIO_SAMPLE_RATE: sample rate the audio is loaded at.
        SPLIT_INTERVAL: segment length in seconds.
        ANNOTATIONS_FPS: label frames per second.
        AUDIO_N_SAMPLES: window length in samples, used to derive the hop size.
        FFT_HOP: hop length in samples, used to derive the overlap length.

    Yields:
        Tuple ``(audio_windowed, split_onsets)`` of numpy arrays, one per segment.
    """
    n_overlapping_frames = 30
    overlap_len = n_overlapping_frames * FFT_HOP
    hop_size = AUDIO_N_SAMPLES - overlap_len
    frames_per_split = ANNOTATIONS_FPS * SPLIT_INTERVAL

    onsets = midi_to_piano_onset_matrix(midi_filename, frames_per_second=ANNOTATIONS_FPS)

    # Read the duration once; it used to be re-read on every loop test.
    audio_duration = librosa.get_duration(path=audio_filename)

    offset = 0
    while offset < audio_duration - SPLIT_INTERVAL:
        audio_original, _ = librosa.load(audio_filename, sr=AUDIO_SAMPLE_RATE, offset=offset, duration=SPLIT_INTERVAL, mono=True)

        # Prepend half an overlap of silence before windowing.
        audio_original = np.concatenate([np.zeros((int(overlap_len / 2),), dtype=np.float32), audio_original])
        audio_windowed, _ = inference.window_audio_file(audio_original, hop_size)

        # Cut the matching label frames and zero-pad a short final slice.
        split_onsets = onsets[int(offset*ANNOTATIONS_FPS):int((offset+SPLIT_INTERVAL)*ANNOTATIONS_FPS), :]
        if split_onsets.shape[0] < frames_per_split:
            padding = frames_per_split - split_onsets.shape[0]
            split_onsets = np.pad(split_onsets, [(0, padding), (0, 0)], 'constant')

        yield np.array(audio_windowed), np.array(split_onsets)
        offset += SPLIT_INTERVAL


# NEED TO INITIALIZE DATA GENERATOR TO SAVE MEMORY (DATA WILL BE LOADED AS THE MODEL NEEDS IT, NOT ALL BEFORE TRAINING)

# create data generator
def data_generator(dataset, AUDIO_SAMPLE_RATE, SPLIT_INTERVAL, ANNOTATIONS_FPS, AUDIO_N_SAMPLES, FFT_HOP):
    """Yield (windowed audio, onset matrix) pairs for every file pair in ``dataset``.

    Each recording is cut into consecutive SPLIT_INTERVAL-second segments. A
    segment is loaded at AUDIO_SAMPLE_RATE, prefixed with half an overlap of
    zeros and windowed with ``inference.window_audio_file``. The matching rows
    of the MIDI onset matrix are sliced out and zero-padded to
    ``ANNOTATIONS_FPS * SPLIT_INTERVAL`` frames when short.

    Args:
        dataset: iterable of ``(audio_filename, midi_filename)`` pairs.
        AUDIO_SAMPLE_RATE: sample rate the audio is loaded at.
        SPLIT_INTERVAL: segment length in seconds.
        ANNOTATIONS_FPS: label frames per second.
        AUDIO_N_SAMPLES: window length in samples, used to derive the hop size.
        FFT_HOP: hop length in samples, used to derive the overlap length.

    Yields:
        Tuple ``(audio_windowed, split_onsets)`` of numpy arrays, one per segment.
    """
    print("CALLED DATA GENERATOR")

    # These values are the same for every file, so compute them once.
    n_overlapping_frames = 30
    overlap_len = n_overlapping_frames * FFT_HOP
    hop_size = AUDIO_N_SAMPLES - overlap_len
    frames_per_split = ANNOTATIONS_FPS * SPLIT_INTERVAL

    for audio_filename, midi_filename in dataset:
        onsets = midi_to_piano_onset_matrix(midi_filename, frames_per_second=ANNOTATIONS_FPS)

        # Read the duration once per file; it used to be re-read on every loop test.
        # NOTE(review): newer librosa releases prefer `path=` over `filename=` —
        # confirm against the installed version.
        audio_duration = librosa.get_duration(filename=audio_filename)

        offset = 0
        while offset < audio_duration - SPLIT_INTERVAL:
            audio_original, _ = librosa.load(audio_filename, sr=AUDIO_SAMPLE_RATE, offset=offset, duration=SPLIT_INTERVAL, mono=True)

            # Prepend half an overlap of silence before windowing.
            audio_original = np.concatenate([np.zeros((int(overlap_len / 2),), dtype=np.float32), audio_original])
            audio_windowed, _ = inference.window_audio_file(audio_original, hop_size)

            # Cut the matching label frames and zero-pad a short final slice.
            split_onsets = onsets[int(offset*ANNOTATIONS_FPS):int((offset+SPLIT_INTERVAL)*ANNOTATIONS_FPS), :]
            if split_onsets.shape[0] < frames_per_split:
                padding = frames_per_split - split_onsets.shape[0]
                split_onsets = np.pad(split_onsets, [(0, padding), (0, 0)], 'constant')

            yield np.array(audio_windowed), np.array(split_onsets)
            offset += SPLIT_INTERVAL

# dtypes and shapes of the (audio, onsets) elements yielded by data_generator
output_types = (tf.float32, tf.float64)  # Modify as per your data types
output_shapes = (tf.TensorShape([SPLIT_INTERVAL, 43844, 1]), tf.TensorShape([172, 88]))  # Modify as per your data shapes

# Split the file pairs: first 80% for training, the rest for validation
train_size = int(0.8 * len(audio_midi_pairs))
train_audio_midi_pairs, val_audio_midi_pairs = audio_midi_pairs[:train_size], audio_midi_pairs[train_size:]

# Both datasets share the same element description
generator_kwargs = dict(output_types=output_types, output_shapes=output_shapes)

train_dataset = tf.data.Dataset.from_generator(
    lambda: data_generator(train_audio_midi_pairs, AUDIO_SAMPLE_RATE, SPLIT_INTERVAL, ANNOTATIONS_FPS, AUDIO_N_SAMPLES, FFT_HOP),
    **generator_kwargs
)

val_dataset = tf.data.Dataset.from_generator(
    lambda: data_generator(val_audio_midi_pairs, AUDIO_SAMPLE_RATE, SPLIT_INTERVAL, ANNOTATIONS_FPS, AUDIO_N_SAMPLES, FFT_HOP),
    **generator_kwargs
)

batch_size = BATCH_SIZE  # Adjust according to your needs


# Pull one element to sanity-check what the generator produces
for audio, onset in train_dataset.take(1):  # Adjust the number taken as needed
    print("Audio shape:", audio.shape)
    print("Onset shape:", onset.shape)
    # Index 0 selects along the first axis of the element
    print("Audio data sample:", audio[0])
    print("Onset data sample:", onset[0])

## Train the model using the preprocessed data

In [None]:
# PRINT OUT ALL TRAINABLE LAYERS OF THE MODEL
# Build the model once and reuse it for both listings, instead of constructing
# a separate model instance for each loop.
inspected_model = models.model()

# Iterate through the layers and print the layer name and its trainable status
for layer in inspected_model.layers:
    print(f"Layer: {layer.name}")
    print(f"Trainable: {layer.trainable}")
    for weight in layer.trainable_weights:
        print(f"\tWeight: {weight.name}, Shape: {weight.shape}")

# If you only want to see layers with trainable weights:
print("\nOnly layers with trainable weights:")
for layer in inspected_model.layers:
    if layer.trainable_weights:
        print(f"Layer: {layer.name}")
        for weight in layer.trainable_weights:
            print(f"\tWeight: {weight.name}, Shape: {weight.shape}")

In [None]:
# CREATE CUSTOM LOSS FUNCTION FOR WEIGHTED BINARY CROSS ENTROPY
class WeightedBinaryCrossEntropy(tf.keras.losses.Loss):
    """Binary cross-entropy with separate weights for positive and negative labels.

    The windowed prediction is first unwrapped into a single (n_times, n_freqs)
    matrix with ``unwrap_output_custom`` and then compared with ``y_true``:

        loss = mean(pos_weight * y * -log(p) + neg_weight * (1 - y) * -log(1 - p))

    Args:
        pos_weight: weight applied to the positive-label term.
        neg_weight: weight applied to the negative-label term.
        from_logits: if True, ``y_pred`` holds logits; otherwise probabilities.
        name: name given to the loss.
    """

    def __init__(self, pos_weight, neg_weight, from_logits=False, name='weighted_binary_crossentropy'):
        super().__init__(name=name)
        self.pos_weight = pos_weight
        self.neg_weight = neg_weight
        self.from_logits = from_logits

    def call(self, y_true, y_pred):
        """Return the mean weighted binary cross-entropy of ``y_pred`` against ``y_true``."""
        # Length in samples of one training segment. AUDIO_SAMPLE_RATE is used so the
        # value agrees with the frame-rate ratio applied in unwrap_output_custom.
        original_length = AUDIO_SAMPLE_RATE * SPLIT_INTERVAL
        n_overlapping_frames = 30

        # Unwrap in both modes so the prediction always lines up with y_true;
        # the logits path used to skip this step.
        unwrapped_y_pred = self.unwrap_output_custom(y_pred, original_length, n_overlapping_frames)

        y_true = tf.cast(y_true, tf.float32)
        unwrapped_y_pred = tf.cast(unwrapped_y_pred, tf.float32)
        pos_weight = tf.cast(self.pos_weight, tf.float32)
        neg_weight = tf.cast(self.neg_weight, tf.float32)

        if self.from_logits:
            # -log(sigmoid(x)) == softplus(-x) and -log(1 - sigmoid(x)) == softplus(x).
            # Computing the terms this way keeps the logits path numerically stable
            # and lets neg_weight apply here as well (it used to be ignored).
            positive_term = y_true * tf.math.softplus(-unwrapped_y_pred)
            negative_term = (1.0 - y_true) * tf.math.softplus(unwrapped_y_pred)
        else:
            # Clip probabilities away from 0 and 1 so the logarithms stay finite.
            epsilon = tf.keras.backend.epsilon()
            probabilities = tf.clip_by_value(unwrapped_y_pred, epsilon, 1.0 - epsilon)
            positive_term = -y_true * tf.math.log(probabilities)
            negative_term = -(1.0 - y_true) * tf.math.log(1.0 - probabilities)

        loss = positive_term * pos_weight + negative_term * neg_weight
        return tf.reduce_mean(loss)

    # custom unwrap output function that remains compatible with TensorFlow's graph execution
    def unwrap_output_custom(self, output: tf.Tensor, audio_original_length: int, n_overlapping_frames: int) -> tf.Tensor:
        """Unwrap batched model predictions to a single matrix.

        Args:
            output: tensor (n_batches, n_times_short, n_freqs)
            audio_original_length: length of original audio signal (in samples)
            n_overlapping_frames: number of overlapping frames in the output

        Returns:
            tensor (n_times, n_freqs); an empty (0, 0) tensor if ``output`` is not rank 3
        """
        output_rank = tf.rank(output)

        def process_output():
            # Drop half of the overlapping frames from each end of every window.
            n_olap = int(0.5 * n_overlapping_frames)
            if n_olap > 0:
                output_processed = output[:, n_olap:-n_olap, :]
            else:
                output_processed = output

            output_shape = tf.shape(output_processed)
            n_output_frames_original = tf.cast(tf.floor(audio_original_length * (ANNOTATIONS_FPS / AUDIO_SAMPLE_RATE)), tf.int32)
            # Join the windows along the time axis.
            unwrapped_output = tf.reshape(output_processed, [output_shape[0] * output_shape[1], output_shape[2]])
            return unwrapped_output[:n_output_frames_original, :]  # trim to original audio length

        def handle_invalid_rank():
            # Print a warning message and return a dummy tensor. The rank is passed
            # to tf.print as its own argument so its value is printed; inside an
            # f-string it would be formatted as a tensor object instead.
            tf.print("Warning: Expected output rank to be 3, got", output_rank)
            return tf.zeros((0, 0), dtype=output.dtype)

        return tf.cond(tf.equal(output_rank, 3), process_output, handle_invalid_rank)

## Try built-in tensorflow train method

In [None]:
# Initialize the model

# The assignment here had no right-hand side, which is a syntax error. It is
# restored from the commented-out initializer that sat directly above it.
# NOTE(review): confirm that a fresh `models.model()` is the intended starting point.
model_train = models.model()
adam_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
onset_loss_function = WeightedBinaryCrossEntropy(pos_weight=0.99, neg_weight=0.01)
# Defined but not passed to compile() below.
contour_loss_function = WeightedBinaryCrossEntropy(pos_weight=0.99, neg_weight=0.01)
note_loss_function = WeightedBinaryCrossEntropy(pos_weight=0.99, neg_weight=0.01)
model_train.compile(optimizer=adam_optimizer, loss={"onset": onset_loss_function, "note": note_loss_function})

# train model
num_epochs = 1

# `batch_size` is not passed: the inputs are tf.data.Dataset objects, and Keras
# says not to specify `batch_size` for dataset inputs.
model_train.fit(train_dataset, validation_data=val_dataset, epochs=num_epochs)

In [None]:
# Save our trained version of the model
# NOTE(review): the prediction command in the evaluation section loads
# "saved_models/dec04_train_99posweight_onthefly", not the path written here —
# confirm which one is intended.

model_train.save('saved_models/dec04_train_99posweight')

## Test and Evaluate Trained Models

In [None]:
# command line execution to output resulting midi from our trained model
# (IPython shell command; arguments are --model_path, then two positional paths)
# NOTE(review): the save cell writes 'saved_models/dec04_train_99posweight', without the "_onthefly" suffix — confirm the model path
!python ../basic_pitch_original/basic_pitch/predict.py --model_path "saved_models/dec04_train_99posweight_onthefly" "model_predictions/our_model/dec04_train_99posweight_onthefly/" "model_predictions/_test_audio/MIDI-Unprocessed_Recital1-3_MID--AUDIO_03_R1_2018_wav--5.wav"

# command line execution to output resulting midi from spotify model for comparison
# (no --model_path is given here)
!python ../basic_pitch_original/basic_pitch/predict.py "model_predictions/spotify_model/" "model_predictions/_test_audio/MIDI-Unprocessed_Recital1-3_MID--AUDIO_03_R1_2018_wav--5.wav"

### Quantitative Evaluation
##### The following code is incomplete. Model Evaluation was done primarily through qualitative analysis (listening to the MIDI by ear)

In [None]:
# FUNCTION TO EVALUATE MODEL
import mir_eval

# Define the evaluation function
def evaluate_model(data, model, threshold=0.5):
    """Run the model over (audio, MIDI) file pairs and average transcription scores.

    This function is still incomplete: predictions are printed, but nothing is
    added to the reference/estimated note lists yet, so the metric loop has no
    work to do and every returned score is NaN.

    Args:
        data: iterable of ``(audio_file, midi_file)`` path pairs.
        model: model (or model path) passed to ``inference.predict``.
        threshold: currently unused.

    Returns:
        dict with keys 'F-measure', 'F-measure-no-offset' and
        'Frame-level Accuracy', each the mean of the collected values, or NaN
        when none were collected.
    """
    # Lists to hold ground truth and predictions for evaluation
    reference_notes = []
    estimated_notes = []

    # Iterate over the dataset
    for audio_file, midi_file in data:
        # Predict from this pair's audio file. The call used to receive an unrelated
        # global named `audio` instead of the loop variable.
        # The second returned item is selected by position so the call does not
        # depend on how many values predict() returns.
        # NOTE(review): Basic Pitch documents predict() as returning
        # (model_output, midi_data, note_events) — confirm for the vendored copy.
        y_pred_midi = inference.predict(audio_file, model)[1]
        print("MODEL OUTPUT:\n", "\n\nData:\n", y_pred_midi)
        # TODO: add the reference notes from `midi_file` and the estimated notes from
        # `y_pred_midi` to the two lists above so the metrics below receive data.

    # Compute metrics using mir_eval
    scores = {
        'F-measure': [],
        'F-measure-no-offset': [],
        'Frame-level Accuracy': []
    }
    for ref, est in zip(reference_notes, estimated_notes):
        # mir_eval requires specific formats for reference and estimated notes
        # NOTE(review): `mir_eval.util.piano_roll_to_intervals`,
        # `mir_eval.transcription.f_measure_without_offset` and
        # `mir_eval.transcription_accuracy` could not be matched to mir_eval's
        # documented API — verify these calls before enabling this loop.
        ref_intervals, ref_pitches = mir_eval.util.piano_roll_to_intervals(ref)
        est_intervals, est_pitches = mir_eval.util.piano_roll_to_intervals(est)

        # Evaluate
        p, r, f_measure, _ = mir_eval.transcription.precision_recall_f1_overlap(ref_intervals, ref_pitches, est_intervals, est_pitches)
        scores['F-measure'].append(f_measure)

        # Compute F_no
        f_no = mir_eval.transcription.f_measure_without_offset(ref_intervals, ref_pitches, est_intervals, est_pitches)
        scores['F-measure-no-offset'].append(f_no)

        # Compute frame-level accuracy
        acc = mir_eval.transcription_accuracy(ref_intervals, ref_pitches, est_intervals, est_pitches)
        scores['Frame-level Accuracy'].append(acc)

    # Average the scores. An empty list maps straight to NaN, which is what
    # np.mean would give, but without its empty-slice RuntimeWarning.
    for key in scores:
        scores[key] = np.mean(scores[key]) if scores[key] else float('nan')

    return scores

In [None]:
# PREPARE EVALUATION DATA
# One test recording and its MIDI file, wrapped in a single-element list of
# (audio, midi) pairs, which is the form evaluate_model iterates over.
audio_filename = "model_predictions/_test_audio/MIDI-Unprocessed_Recital1-3_MID--AUDIO_03_R1_2018_wav--5.wav"
midi_filename = "model_predictions/_test_audio/MIDI-Unprocessed_Recital1-3_MID--AUDIO_03_R1_2018_wav--5.midi"

eval_audio_midi_pairs = [(audio_filename, midi_filename)]


In [None]:
# Evaluate the model
# You need to prepare 'validation_data' in the format that your model expects
# NOTE(review): `model_nov26_02` is not defined anywhere in the visible notebook
# code — confirm which model object or path should be evaluated here.

scores = evaluate_model(eval_audio_midi_pairs, model_nov26_02)
print(scores)
