# Data Preparation

In [6]:
import tensorflow as tf
from tensorflow.keras.preprocessing import image
import numpy as np
import os
import csv

# Adapt our files to be the right size for the model
def preprocess_image(image_path, target_size=(528, 528)):
    """Load a grayscale PNG and shape it for the EfficientNet backbone.

    Args:
        image_path: Path of a PNG file (Python string or scalar string tensor).
        target_size: ``(height, width)`` the image is resized to. Defaults to
            528x528, the input size the CNN in this notebook is built with.

    Returns:
        A float32 tensor of shape ``(height, width, 3)`` whose values stay in
        the original 0-255 intensity range.
    """
    raw_bytes = tf.io.read_file(image_path)
    img = tf.image.decode_png(raw_bytes, channels=1)
    img = tf.image.resize(img, target_size)  # Resize the image
    # Do NOT rescale to [-1, 1] here: the Keras EfficientNet models contain
    # their own rescaling/normalisation layers and expect raw 0-255 pixels.
    # Scaling twice would feed the frozen backbone near-constant inputs.
    img = tf.cast(img, tf.float32)
    img = tf.image.grayscale_to_rgb(img)  # Duplicate the single channel to RGB
    return img



import os
import csv
import tensorflow as tf

def generate_datasets(image_folder, label_file, batch_size=32):
    """Build the pitch and duration ``tf.data`` pipelines from a label CSV.

    Each CSV row holds a filename in column 0 and a bracketed, comma-separated
    list of integers in column 1 (e.g. ``"[3, 7, 12]"``). The first row seen
    for a filename is treated as its pitch labels; a row whose filename equals
    the one on the immediately preceding row is treated as duration labels.

    Args:
        image_folder: Directory that the CSV filenames are relative to.
        label_file: Path of the CSV file described above.
        batch_size: Number of examples per batch.

    Returns:
        ``(pitch_dataset, duration_dataset)``; each yields batches of
        ``(image, labels)`` where the image comes from ``preprocess_image``.

    Raises:
        ValueError: If a label token is not an integer, or if the label lists
            of one dataset do not all have the same length (they are stacked
            into a single dense tensor).
    """

    def _parse_labels(cell):
        # Tolerates "[1, 2]", "[1,2]" and "[]" alike; int() ignores padding.
        tokens = cell.strip().strip('[]').split(',')
        return [int(tok) for tok in tokens if tok.strip()]

    def _build_dataset(paths, labels):
        # Explicit dtypes keep an empty list from turning into a float32
        # tensor, which tf.io.read_file would reject when the map is traced.
        dataset = tf.data.Dataset.from_tensor_slices((
            tf.constant(paths, dtype=tf.string),
            tf.constant(labels, dtype=tf.int32),
        ))
        dataset = dataset.map(
            lambda path, label: (preprocess_image(path), label),
            num_parallel_calls=tf.data.experimental.AUTOTUNE,
        )
        return dataset.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)

    pitch_image_paths = []
    pitch_labels = []
    duration_image_paths = []
    duration_labels = []

    # newline='' is what the csv module asks for so quoted newlines survive.
    with open(label_file, 'r', newline='') as csvfile:
        previous_filename = None
        for row in csv.reader(csvfile):
            if not row:
                # Blank line in the CSV: nothing to record.
                continue
            filename = row[0]
            labels = _parse_labels(row[1])
            image_path = os.path.join(image_folder, filename)

            # If the filename matches the previous one, it's a duration label
            if filename == previous_filename:
                duration_image_paths.append(image_path)
                duration_labels.append(labels)
            else:
                pitch_image_paths.append(image_path)
                pitch_labels.append(labels)

            previous_filename = filename

    pitch_dataset = _build_dataset(pitch_image_paths, pitch_labels)
    duration_dataset = _build_dataset(duration_image_paths, duration_labels)
    return pitch_dataset, duration_dataset


# Model Architecture

In [5]:
from tensorflow.keras.applications import EfficientNetB6
from tensorflow.keras.layers import Reshape, Input
from tensorflow.keras.models import Model

def get_cnn_model():
    """Return a frozen EfficientNetB6 that emits a sequence of feature vectors.

    The ImageNet-pretrained network is built without its classification head
    for 528x528 RGB inputs and marked non-trainable. Its spatial feature map is
    flattened so that each spatial position becomes one step of a sequence,
    keeping the channel axis as the feature axis.
    """
    backbone = EfficientNetB6(
        weights="imagenet",
        include_top=False,
        input_shape=(528, 528, 3),
    )
    # Used purely as a feature extractor, so no weights are updated.
    backbone.trainable = False

    feature_map = backbone.output
    channels = feature_map.shape[-1]
    # Collapse the spatial axes into one sequence axis for the Transformer.
    feature_sequence = Reshape((-1, channels))(feature_map)

    return Model(backbone.input, feature_sequence)


In [None]:
# Build the frozen EfficientNetB6 feature extractor once; later cells pass this
# same instance to every model they construct.
cnn_model = get_cnn_model()

In [None]:
from tensorflow.keras.layers import Layer, MultiHeadAttention, Dense, LayerNormalization, Embedding, Dropout
import tensorflow as tf

class TransformerEncoderBlock(Layer):
    """One Transformer encoder block: self-attention plus a feed-forward net.

    Both sub-layers are wrapped in a residual connection followed by layer
    normalisation. The feed-forward net always ends in ``embed_dim`` units, so
    the residual additions only line up when the incoming features are also
    ``embed_dim`` wide. When they are not (the CNN features in this notebook
    are much wider), the inputs are first projected to ``embed_dim`` by a
    dense layer that is created lazily in ``build``.

    Args:
        embed_dim: Width of the block's output and per-head key size.
        dense_dim: Hidden width of the feed-forward net.
        num_heads: Number of attention heads.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.attention = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.dense_proj = tf.keras.Sequential([
            Dense(dense_dim, activation="relu"),
            Dense(embed_dim),
        ])
        self.layernorm_1 = LayerNormalization()
        self.layernorm_2 = LayerNormalization()
        # Only materialised in build() when the input width differs.
        self.input_proj = None

    def build(self, input_shape):
        """Create the input projection if the feature width is not embed_dim."""
        if input_shape[-1] != self.embed_dim:
            self.input_proj = Dense(self.embed_dim)
        super().build(input_shape)

    def call(self, inputs):
        """Encode ``inputs`` of shape (batch, sequence, features)."""
        if self.input_proj is not None:
            # Without this, `proj_input + proj_output` below would add tensors
            # with different last dimensions and fail.
            inputs = self.input_proj(inputs)
        attention_output = self.attention(inputs, inputs)
        proj_input = self.layernorm_1(inputs + attention_output)
        proj_output = self.dense_proj(proj_input)
        return self.layernorm_2(proj_input + proj_output)

class TransformerDecoderBlock(Layer):
    """One Transformer decoder block that maps token ids to token probabilities.

    The block embeds the target tokens, applies causally masked self-attention,
    cross-attends to the encoder output, runs a feed-forward net, and finally
    projects to a softmax over the vocabulary. Every sub-layer is followed by
    a residual connection and layer normalisation.

    Args:
        embed_dim: Width of the token embeddings and of the block's hidden state.
        ff_dim: Hidden width of the feed-forward net.
        num_heads: Number of attention heads in both attention layers.
        vocab_size: Number of distinct tokens (embedding rows and output units).
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, embed_dim, ff_dim, num_heads, vocab_size, **kwargs):
        super().__init__(**kwargs)
        self.attention_1 = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.attention_2 = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = tf.keras.Sequential([
            Dense(ff_dim, activation="relu"),
            Dense(embed_dim),
        ])
        self.layernorm_1 = LayerNormalization()
        self.layernorm_2 = LayerNormalization()
        self.layernorm_3 = LayerNormalization()
        self.embedding = Embedding(vocab_size, embed_dim)
        self.dropout_1 = Dropout(0.3)
        self.dropout_2 = Dropout(0.5)
        self.out = Dense(vocab_size, activation="softmax")

    def _causal_attention_mask(self, embedded):
        """Return a (batch, seq, seq) mask that is 1 where query i may see key j."""
        shape = tf.shape(embedded)
        batch_size, seq_len = shape[0], shape[1]
        rows = tf.range(seq_len)[:, tf.newaxis]
        cols = tf.range(seq_len)
        lower_triangle = tf.cast(rows >= cols, dtype="int32")
        return tf.tile(lower_triangle[tf.newaxis, :, :], tf.stack([batch_size, 1, 1]))

    def call(self, inputs, encoder_outputs, training=None, mask=None):
        """Decode one batch.

        Args:
            inputs: Integer token ids, (batch, sequence).
            encoder_outputs: Encoder features attended to by the second
                attention layer.
            training: Standard Keras training flag; ``None`` lets Keras decide.
            mask: Accepted for signature compatibility; it is not applied.

        Returns:
            Softmax probabilities of shape (batch, sequence, vocab_size).
        """
        inputs = self.embedding(inputs)

        # Each position may only attend to itself and earlier positions;
        # otherwise teacher-forced training lets the model read the answer.
        causal_mask = self._causal_attention_mask(inputs)
        attention_output_1 = self.attention_1(
            query=inputs, value=inputs, key=inputs, attention_mask=causal_mask
        )
        out_1 = self.layernorm_1(inputs + attention_output_1)

        attention_output_2 = self.attention_2(
            query=out_1, value=encoder_outputs, key=encoder_outputs
        )
        out_2 = self.layernorm_2(out_1 + attention_output_2)

        ffn_output = self.ffn(out_2)
        ffn_output = self.dropout_1(ffn_output, training=training)
        # Residual + norm around the feed-forward net; layernorm_3 and
        # dropout_2 were created in __init__ but never applied before.
        ffn_output = self.layernorm_3(out_2 + ffn_output)
        ffn_output = self.dropout_2(ffn_output, training=training)
        return self.out(ffn_output)

class MusicGenerationModel(tf.keras.Model):
    """Image-to-token-sequence model: CNN features -> encoder -> decoder.

    Args:
        cnn_model: Model that maps an image batch to a feature sequence.
        embed_dim: Hidden width shared by the encoder and decoder blocks.
        dense_dim: Feed-forward width of the encoder block.
        num_heads: Number of attention heads in both blocks.
        vocab_size: Number of output tokens of the decoder.
        ff_dim: Feed-forward width of the decoder block.
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(self, cnn_model, embed_dim=128, dense_dim=512, num_heads=8,
                 vocab_size=22, ff_dim=512, **kwargs):
        super().__init__(**kwargs)
        self.cnn_model = cnn_model
        self.encoder = TransformerEncoderBlock(
            embed_dim=embed_dim, dense_dim=dense_dim, num_heads=num_heads
        )
        self.decoder = TransformerDecoderBlock(
            embed_dim=embed_dim, ff_dim=ff_dim, num_heads=num_heads, vocab_size=vocab_size
        )

    def call(self, image, target=None, training=None):
        """Predict token probabilities for ``target`` given ``image``.

        Can be called as ``model(image, target)`` or, as ``Model.fit`` and
        ``Model.predict`` do, with a single ``(image, target)`` pair as the
        first argument.

        Args:
            image: Image batch, or an ``(image, target)`` tuple/list.
            target: Integer token ids fed to the decoder.
            training: Standard Keras training flag, forwarded to the decoder.

        Returns:
            The decoder output for the given target sequence.

        Raises:
            ValueError: If ``target`` is omitted and ``image`` is not a pair.
        """
        if target is None:
            # Keras hands call() exactly one inputs argument during
            # fit/evaluate/predict, so both tensors arrive packed together.
            if not isinstance(image, (tuple, list)) or len(image) != 2:
                raise ValueError(
                    "Expected an (image, target) pair when `target` is not given."
                )
            image, target = image

        cnn_features = self.cnn_model(image)
        encoded_features = self.encoder(cnn_features)
        output = self.decoder(target, encoded_features, training=training)
        return output


In [None]:
# Module-level model instance; `predict_sequence` further down reads this
# global. It is constructed here but not compiled or trained in this cell.
model = MusicGenerationModel(cnn_model)

# Model Training

In [None]:
# Load the datasets: each one yields batches of (image, labels).
pitch_dataset, duration_dataset = generate_datasets('../raw_data/sheet_images', '../raw_data/labels.csv')

# One model per prediction target. Both wrap the same `cnn_model` instance,
# which `get_cnn_model` marks as non-trainable.
pitch_model = MusicGenerationModel(cnn_model)
duration_model = MusicGenerationModel(cnn_model)

# Compile the models
# The labels are integer ids, hence the sparse categorical cross-entropy loss.
pitch_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
duration_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

# Train the models separately
# NOTE(review): the datasets provide only (image, labels), while
# MusicGenerationModel.call also takes a target sequence for the decoder.
# Confirm how the decoder input is meant to be supplied before running this.
pitch_model.fit(pitch_dataset, epochs=25, validation_data=None)  # Add validation data if available
duration_model.fit(duration_dataset, epochs=25, validation_data=None)  # Add validation data if available


# Extract Features from CNN 

In [None]:
# A subclassed Keras model such as MusicGenerationModel has no symbolic
# `.input`, so `pitch_model.input` raises. The wrapped `cnn_model` is a
# functional model and does expose matching `.input` / `.output` tensors,
# so the extractors are built from that graph instead.
pitch_feature_extractor = tf.keras.Model(
    inputs=pitch_model.cnn_model.input,
    outputs=pitch_model.cnn_model.output,
)
duration_feature_extractor = tf.keras.Model(
    inputs=duration_model.cnn_model.input,
    outputs=duration_model.cnn_model.output,
)

# Concatenate the outputs along the last dimension
def combine_features(pitch_features, duration_features):
    """Join the pitch and duration feature tensors along their last axis.

    Args:
        pitch_features: Tensor produced by the pitch feature extractor.
        duration_features: Tensor produced by the duration feature extractor;
            must match ``pitch_features`` on every axis except the last.

    Returns:
        A single tensor whose last axis is the two inputs' last axes
        laid end to end.
    """
    return tf.concat(values=[pitch_features, duration_features], axis=-1)

# Assuming `transformer_encoder` is your Transformer encoder model
# NOTE(review): neither `image_input` nor `transformer_encoder` is defined in
# the cells visible here; both must exist before this cell can run.
# The same image batch goes through both extractors and the two feature
# tensors are joined on the last axis before being encoded.
combined_features = combine_features(pitch_feature_extractor(image_input), duration_feature_extractor(image_input))
encoded_output = transformer_encoder(combined_features)


In [None]:
# Attempt to integrate the entire model

class CompleteMusicGenerationModel(tf.keras.Model):
    """Joint model that decodes from combined pitch and duration features.

    Args:
        pitch_feature_extractor: Callable mapping an image batch to pitch features.
        duration_feature_extractor: Callable mapping an image batch to
            duration features.
        transformer_encoder: Callable applied to the combined features.
        transformer_decoder: Callable invoked as
            ``transformer_decoder(target_sequence, encoded_output)``.
    """

    def __init__(self, pitch_feature_extractor, duration_feature_extractor, transformer_encoder, transformer_decoder):
        super().__init__()
        self.pitch_feature_extractor = pitch_feature_extractor
        self.duration_feature_extractor = duration_feature_extractor
        self.transformer_encoder = transformer_encoder
        self.transformer_decoder = transformer_decoder

    def call(self, image, target_sequence=None):
        """Run the full pipeline for one batch.

        Can be called as ``model(image, target_sequence)`` or, as ``Model.fit``
        does, with a single ``(image, target_sequence)`` pair as the first
        argument.

        Args:
            image: Image batch, or an ``(image, target_sequence)`` tuple/list.
            target_sequence: Token ids passed to the decoder.

        Returns:
            Whatever ``transformer_decoder`` returns for the encoded features.

        Raises:
            ValueError: If ``target_sequence`` is omitted and ``image`` is not
                a pair.
        """
        if target_sequence is None:
            # Keras passes call() a single inputs argument during fit/predict,
            # so the two tensors arrive packed in one structure.
            if not isinstance(image, (tuple, list)) or len(image) != 2:
                raise ValueError(
                    "Expected an (image, target_sequence) pair when "
                    "`target_sequence` is not given."
                )
            image, target_sequence = image

        pitch_features = self.pitch_feature_extractor(image)
        duration_features = self.duration_feature_extractor(image)
        combined_features = combine_features(pitch_features, duration_features)
        encoded_output = self.transformer_encoder(combined_features)
        decoded_output = self.transformer_decoder(target_sequence, encoded_output)
        return decoded_output

# Initialize and compile the complete model
# NOTE(review): `transformer_encoder` and `transformer_decoder` are not defined
# in the cells visible here; they must be created before this cell runs.
complete_model = CompleteMusicGenerationModel(
    pitch_feature_extractor=pitch_feature_extractor,
    duration_feature_extractor=duration_feature_extractor,
    transformer_encoder=transformer_encoder,
    transformer_decoder=transformer_decoder
)

# Same optimizer and loss as the separately trained pitch/duration models.
complete_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')


In [None]:
# Assuming `train_dataset` has both image inputs and target sequences
# NOTE(review): `train_dataset` and `validation_dataset` are not created in the
# cells visible here; define them before running this cell.
complete_model.fit(train_dataset, epochs=25, validation_data=validation_dataset)


# Convert to MusicXML

In [None]:
from music21 import stream, note

def predict_sequence(image):
    """Run the global ``model`` on ``image`` and decode its raw prediction.

    Args:
        image: Input handed unchanged to ``model.predict``.

    Returns:
        The result of ``decode_sequence`` applied to the model output.
        ``decode_sequence`` still has to be implemented; it is not defined in
        the cells visible here.
    """
    raw_prediction = model.predict(image)
    decoded = decode_sequence(raw_prediction)  # Implement a function to decode sequences
    return decoded

def convert_to_musicxml(sequence, output_path='output.xml', symbol_table=None):
    """Write a sequence of music symbols to a MusicXML file.

    Args:
        sequence: Iterable of symbol strings such as ``'C4 Quarter'``.
        output_path: Destination file. Defaults to ``'output.xml'``.
        symbol_table: Optional mapping from a symbol string to a
            ``(pitch_name, quarter_length)`` pair. When omitted, only
            ``'C4 Quarter'`` is known, mapped to a C4 note of one quarter
            length. Extend this mapping to support further symbols.

    Symbols missing from the table are skipped silently.
    """
    if symbol_table is None:
        symbol_table = {'C4 Quarter': ('C4', 1.0)}

    s = stream.Stream()
    for sym in sequence:
        spec = symbol_table.get(sym)
        if spec is None:
            # Unknown symbol: nothing is appended for it.
            continue
        pitch_name, quarter_length = spec
        s.append(note.Note(pitch_name, quarterLength=quarter_length))
    s.write('musicxml', fp=output_path)

# Inference
# NOTE(review): `load_image` is not defined in the cells visible here; it needs
# to return something `model.predict` accepts.
new_image = load_image('path_to_image')  # Load a new image
# `predict_sequence` uses the module-level `model` defined earlier.
predicted_sequence = predict_sequence(new_image)
# Writes the result to 'output.xml' in the current working directory.
convert_to_musicxml(predicted_sequence)
