#Installations

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
%cd /content/drive/MyDrive/MIDI-VAE-NEW

/content/drive/MyDrive/MIDI-VAE-NEW


In [4]:

# Set the path to your MIDI files
midi_folder_path = "/content/drive/MyDrive/MIDI-VAE-NEW/data"

In [5]:
import os
print(f"Folder exists: {os.path.exists(midi_folder_path)}")
print(f"Files in folder: {os.listdir(midi_folder_path)}")

Folder exists: True
Files in folder: ['Pop', 'Jazz']


In [6]:
# First uninstall existing versions
#!pip uninstall -y numpy tensorflow


In [7]:
#!pip install pretty_midi numpy tensorflow matplotlib

In [8]:
# Install compatible versions
!pip install numpy==1.23.5 tensorflow==2.12.0
!pip install pretty_midi==0.2.9



In [9]:
!pip install tqdm



#Commun : MIDI-VAE Class with update Project environemnt (Python 3.11)

In [10]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

# Add to existing imports
import gc
import warnings


# Disable warnings
warnings.filterwarnings('ignore', category=RuntimeWarning)

In [11]:


class MIDI_VAE(keras.Model):
    def __init__(self, sequence_length, input_dim, latent_dim, intermediate_dim):
        """
        Modern implementation of MIDI-VAE using TensorFlow 2.x

        Args:
            sequence_length: Length of the MIDI sequence
            input_dim: Dimension of input features
            latent_dim: Dimension of the latent space
            intermediate_dim: Dimension of the intermediate layers
        """
        super(MIDI_VAE, self).__init__()

        # Encoder
        self.encoder_lstm = layers.LSTM(intermediate_dim,
                                      return_sequences=True)
        self.encoder_lstm2 = layers.LSTM(intermediate_dim,
                                       return_state=True)

        # VAE layers
        self.dense_mean = layers.Dense(latent_dim)
        self.dense_log_var = layers.Dense(latent_dim)

        # Decoder
        self.decoder_initial = layers.Dense(intermediate_dim,
                                          activation='relu')
        self.decoder_lstm = layers.LSTM(intermediate_dim,
                                      return_sequences=True,
                                      stateful=False)
        self.decoder_dense = layers.Dense(input_dim,
                                        activation='sigmoid')

        # Store parameters
        self.latent_dim = latent_dim
        self.sequence_length = sequence_length

    def encode(self, x):
        # Initial LSTM encoding
        x = self.encoder_lstm(x)

        # Get final states
        _, state_h, state_c = self.encoder_lstm2(x)

        # Generate latent parameters
        z_mean = self.dense_mean(state_h)
        z_log_var = self.dense_log_var(state_h)

        return z_mean, z_log_var

    def reparameterize(self, z_mean, z_log_var):
        batch = tf.shape(z_mean)[0]
        epsilon = tf.random.normal(shape=(batch, self.latent_dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

    def decode(self, z):
        # Initial state from latent vector
        initial_state = self.decoder_initial(z)

        # Repeat vector to create sequence
        x = tf.repeat(tf.expand_dims(initial_state, 1),
                     self.sequence_length, axis=1)

        # Decode sequence
        x = self.decoder_lstm(x)
        reconstruction = self.decoder_dense(x)

        return reconstruction

    def call(self, x):
        # Full forward pass
        z_mean, z_log_var = self.encode(x)
        z = self.reparameterize(z_mean, z_log_var)
        reconstruction = self.decode(z)

        # Add KL divergence loss
        kl_loss = -0.5 * tf.reduce_mean(
            z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1
        )
        self.add_loss(kl_loss)

        return reconstruction

# Training utilities
@tf.function
def train_step(model, x, optimizer):
    """Single training step"""
    with tf.GradientTape() as tape:
        reconstruction = model(x)
        # Reconstruction loss (binary crossentropy)
        reconstruction_loss = tf.reduce_mean(
            tf.reduce_sum(
                keras.losses.binary_crossentropy(x, reconstruction),
                #axis=[1, 2]
                axis=[1] # problematic line corrected
            )
        )
        # Total loss (including KL divergence added in model.call)
        total_loss = reconstruction_loss + tf.reduce_sum(model.losses)

    # Compute and apply gradients
    grads = tape.gradient(total_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    return total_loss, reconstruction_loss

# Example usage
def create_and_compile_model(sequence_length=32,
                           input_dim=128,
                           latent_dim=256,
                           intermediate_dim=512):
    """Create and compile the MIDI-VAE model"""
    model = MIDI_VAE(sequence_length, input_dim, latent_dim, intermediate_dim)
    optimizer = keras.optimizers.Adam(learning_rate=1e-3)

    # Build model
    dummy_input = tf.zeros((1, sequence_length, input_dim))
    _ = model(dummy_input)

    return model, optimizer

#MIDI Preprocessing Test Script

In [12]:
import pretty_midi
import tensorflow as tf
import numpy as np
import glob
import os
from collections import defaultdict
import matplotlib.pyplot as plt
from tqdm import tqdm

class MIDIPreprocessor:
    def __init__(self, sequence_length=32,
                 min_pitch=21, max_pitch=109,
                 time_step=0.125):
        self.sequence_length = sequence_length
        self.min_pitch = min_pitch
        self.max_pitch = max_pitch
        self.time_step = time_step
        self.n_pitches = max_pitch - min_pitch + 1

    def load_midi_file(self, file_path):
        try:
            # Use explicit integer types
            pretty_midi.pretty_midi.MAX_TICK = np.int64(1e7)
            midi_data = pretty_midi.PrettyMIDI(file_path)
            return midi_data
        except Exception as e:
            print(f"Error loading {file_path}: {str(e)}")
            return None

    def midi_to_piano_roll(self, midi_data):
        try:
            # Get piano roll with explicit sampling rate
            fs = int(1/self.time_step)
            piano_roll = midi_data.get_piano_roll(fs=fs)

            # Ensure we have enough pitches
            if piano_roll.shape[0] < self.max_pitch + 1:
                pad_size = self.max_pitch + 1 - piano_roll.shape[0]
                piano_roll = np.pad(piano_roll, ((0, pad_size), (0, 0)))

            # Crop to our pitch range
            piano_roll = piano_roll[self.min_pitch:self.max_pitch + 1]

            # Convert to binary (note on/off)
            piano_roll = (piano_roll > 0).astype(np.float32)

            return piano_roll.T
        except Exception as e:
            print(f"Error in piano roll conversion: {str(e)}")
            return None

    def extract_sequences(self, piano_roll):
        if piano_roll is None:
            return np.array([])

        sequences = []
        for i in range(0, len(piano_roll) - self.sequence_length + 1):
            sequence = piano_roll[i:i + self.sequence_length]
            if np.any(sequence):
                sequences.append(sequence)

        if not sequences:
            return np.array([])

        return np.array(sequences)

    def create_tf_dataset(self, midi_files, batch_size=32, shuffle=True):
        all_sequences = []
        processed_files = 0

        for file_path in midi_files:
            try:
                midi_data = self.load_midi_file(file_path)
                if midi_data is None:
                    continue

                piano_roll = self.midi_to_piano_roll(midi_data)
                if piano_roll is None:
                    continue

                sequences = self.extract_sequences(piano_roll)
                if len(sequences) > 0:
                    all_sequences.append(sequences)
                    processed_files += 1

                if processed_files % 10 == 0:
                    print(f"Processed {processed_files} files...")

            except Exception as e:
                print(f"Error processing {file_path}: {str(e)}")
                continue

        if not all_sequences:
            raise ValueError("No valid sequences were extracted from the MIDI files")

        all_sequences = np.concatenate(all_sequences, axis=0)
        print(f"Created dataset with {len(all_sequences)} sequences")

        dataset = tf.data.Dataset.from_tensor_slices(all_sequences)
        if shuffle:
            dataset = dataset.shuffle(10000)

        dataset = dataset.batch(batch_size)
        dataset = dataset.prefetch(tf.data.AUTOTUNE)

        return dataset
#-----
def test_midi_preprocessing(midi_folder_path, check_subfolders=True):
    """
    Test MIDI preprocessing pipeline with style folder support

    Args:
        midi_folder_path: Root folder containing MIDIs
        check_subfolders: If True, looks in style subfolders
    """
    print("=== MIDI Preprocessing Test ===")

    # 1. File Discovery
    print("\n1. Checking MIDI files...")
    if check_subfolders:
        # Look in style subfolders
        style_folders = [f for f in os.listdir(midi_folder_path)
                        if os.path.isdir(os.path.join(midi_folder_path, f))]
        midi_files = []
        for style in style_folders:
            style_path = os.path.join(midi_folder_path, style)
            midi_files.extend(glob.glob(os.path.join(style_path, "*.mid")))
            midi_files.extend(glob.glob(os.path.join(style_path, "*.midi")))
    else:
        # Original behavior (root folder only)
        midi_files = glob.glob(os.path.join(midi_folder_path, "*.mid"))
        midi_files.extend(glob.glob(os.path.join(midi_folder_path, "*.midi")))

    if not midi_files:
        print("❌ No MIDI files found!")
        return None, None

    print(f"✓ Found {len(midi_files)} MIDI files")
    print("\nSample files:")
    for f in midi_files[:3]:
        print(f"- {os.path.relpath(f, midi_folder_path)}")

    # 2. Initialize Preprocessor
    print("\n2. Initializing preprocessor...")
    preprocessor = MIDIPreprocessor(
        sequence_length=32,
        min_pitch=21,
        max_pitch=109,
        time_step=0.125
    )

    # 3. Test Processing
    print("\n3. Testing file processing...")
    test_file = midi_files[0]  # Test first file
    print(f"Testing file: {os.path.basename(test_file)}")

    try:
        midi_data = preprocessor.load_midi_file(test_file)
        if midi_data is None:
            print("❌ Failed to load test file")
            return None, None

        piano_roll = preprocessor.midi_to_piano_roll(midi_data)
        if piano_roll is None:
            print("❌ Failed to create piano roll")
            return None, None

        sequences = preprocessor.extract_sequences(piano_roll)
        if len(sequences) == 0:
            print("❌ No sequences extracted")
            return None, None

        print("✓ Successfully processed:")
        print(f"- Duration: {midi_data.get_end_time():.2f} sec")
        print(f"- Extracted {len(sequences)} sequences")
        print(f"- Sequence shape: {sequences[0].shape}")

        return sequences, preprocessor

    except Exception as e:
        print(f"❌ Error processing {test_file}: {str(e)}")
        return None, None
#----

# Run the test
print("Starting MIDI preprocessing test...")
# Test with style subfolders
sequences, preprocessor = test_midi_preprocessing(
    midi_folder_path="/content/drive/MyDrive/MIDI-VAE-NEW/data",
    check_subfolders=True
)

dataset, preprocessor = test_midi_preprocessing(midi_folder_path)

# if dataset is not None:
#     print("\nDataset creation successful!")
#     # Show first batch info
#     for batch in dataset.take(1):
#         print(f"Batch shape: {batch.shape}")
#         print(f"Batch min value: {tf.reduce_min(batch)}")
#         print(f"Batch max value: {tf.reduce_max(batch)}")
# else:
#     print("\nFailed to create dataset. Please check the errors above.")

Starting MIDI preprocessing test...
=== MIDI Preprocessing Test ===

1. Checking MIDI files...
✓ Found 99 MIDI files

Sample files:
- Pop/Bobby_Vinton_-_Sealed_With_a_Kiss.mid
- Pop/Boyzone_-_Fathers_And_Sons.mid
- Pop/Bonnie_Tyler_-_Total_Eclipse_of_the_Heart.mid

2. Initializing preprocessor...

3. Testing file processing...
Testing file: Bobby_Vinton_-_Sealed_With_a_Kiss.mid
✓ Successfully processed:
- Duration: 146.51 sec
- Extracted 1141 sequences
- Sequence shape: (32, 89)
=== MIDI Preprocessing Test ===

1. Checking MIDI files...
✓ Found 99 MIDI files

Sample files:
- Pop/Bobby_Vinton_-_Sealed_With_a_Kiss.mid
- Pop/Boyzone_-_Fathers_And_Sons.mid
- Pop/Bonnie_Tyler_-_Total_Eclipse_of_the_Heart.mid

2. Initializing preprocessor...

3. Testing file processing...
Testing file: Bobby_Vinton_-_Sealed_With_a_Kiss.mid
✓ Successfully processed:
- Duration: 146.51 sec
- Extracted 1141 sequences
- Sequence shape: (32, 89)


# Convert midi to wav

In [13]:
from IPython.display import Audio

# # Convert MIDI to WAV using fluidsynth
# !apt-get install fluidsynth
# # !wget https://download.sf2tool.com/GeneralUser_GS_1.471.zip
# # !unzip GeneralUser_GS_1.471.zip
# !wget https://www.philscomputerlab.com/uploads/3/7/2/3/37231621/weedsgm3.sf2


#V1 only

In [14]:
# midi_file_path = "/content/drive/MyDrive/MIDI-VAE-NEW/generated/generated_music.mid"
# # Convert MIDI to WAV
# wav_file_path = "/content/drive/MyDrive/MIDI-VAE-NEW/generated/generated_music.wav"
# !fluidsynth -ni weedsgm3.sf2 {midi_file_path} -F {wav_file_path} -r 44100

# # The rate argument is added to Audio to specify the sample rate.
# Audio(wav_file_path, rate=44100)

In [15]:
# midi_file_path = "/content/drive/MyDrive/MIDI-VAE-NEW/generated/interpolated_music.mid"
# # Convert MIDI to WAV
# wav_file_path = "/content/drive/MyDrive/MIDI-VAE-NEW/generated/interpolated_music.wav"
# !fluidsynth -ni weedsgm3.sf2 {midi_file_path} -F {wav_file_path} -r 44100

# # The rate argument is added to Audio to specify the sample rate.
# Audio(wav_file_path, rate=44100)

# Style Conditioned MIDI-VAE

## Class definition

In [16]:
# 1. Remove mixed precision code completely (no imports or policy settings)

# 2. Use this simplified StyleConditionedMIDI_VAE class
class StyleConditionedMIDI_VAE(keras.Model):
    def __init__(self, sequence_length, input_dim, latent_dim, intermediate_dim, num_styles):
        super(StyleConditionedMIDI_VAE, self).__init__()

        # Style embedding
        self.style_embedding_dim = 32
        self.style_embedding = layers.Embedding(num_styles, self.style_embedding_dim)

        # Encoder
        self.encoder_lstm = layers.LSTM(intermediate_dim, return_sequences=True)
        self.encoder_lstm2 = layers.LSTM(intermediate_dim, return_state=True)

        # VAE layers
        self.dense_mean = layers.Dense(latent_dim)
        self.dense_log_var = layers.Dense(latent_dim)

        # Decoder
        self.decoder_initial = layers.Dense(intermediate_dim, activation='relu')
        self.decoder_lstm = layers.LSTM(intermediate_dim, return_sequences=True)
        self.decoder_dense = layers.Dense(input_dim, activation='sigmoid')

        # Store parameters
        self.latent_dim = latent_dim
        self.sequence_length = sequence_length
        self.num_styles = num_styles

    def call(self, inputs):
        x, style_ids = inputs
        x = tf.cast(x, tf.float32)  # Ensure float32

        # Encode
        style_embed = self.style_embedding(style_ids)
        x = self.encoder_lstm(x)
        _, state_h, _ = self.encoder_lstm2(x)
        combined = tf.concat([state_h, style_embed], axis=-1)
        z_mean = self.dense_mean(combined)
        z_log_var = self.dense_log_var(combined)

        # Reparameterize
        z = z_mean + tf.exp(0.5 * z_log_var) * tf.random.normal(shape=tf.shape(z_mean))

        # Decode
        style_embed = self.style_embedding(style_ids)
        z_combined = tf.concat([z, style_embed], axis=-1)
        initial_state = self.decoder_initial(z_combined)
        x = tf.repeat(tf.expand_dims(initial_state, 1), self.sequence_length, axis=1)
        x = self.decoder_lstm(x)
        reconstruction = self.decoder_dense(x)

        # KL loss
        kl_loss = -0.5 * tf.reduce_mean(z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1)
        self.add_loss(kl_loss)

        return reconstruction

# 3. Use this simplified dataset creation
def create_style_conditioned_dataset(datasets, style_ids, batch_size=32):
    """Create dataset with float32 dtype and proper batching"""
    # Concatenate all data
    all_data = np.concatenate([datasets[style] for style in datasets], axis=0)
    all_styles = np.concatenate([style_ids[style] for style in style_ids], axis=0)

    # Create dataset pipeline
    dataset = tf.data.Dataset.from_tensor_slices((all_data, all_styles))
    dataset = dataset.shuffle(10000)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    return dataset

# 4. Simplified training approach
def train_model():
    # 1. Load data
    _, preprocessor = test_midi_preprocessing(midi_folder_path, check_subfolders=True)
    datasets, style_ids = load_style_data(midi_folder_path, {'Jazz':0, 'Mozart':1}, preprocessor)

    # 2. Create dataset with smaller batch size
    dataset = create_style_conditioned_dataset(datasets, style_ids, batch_size=8)  # Reduced batch

    # 3. Create smaller model
    model, optimizer = create_and_compile_model_style_conditioned(
        sequence_length=32,
        input_dim=89,  # From your data
        latent_dim=64,  # Reduced
        intermediate_dim=128,  # Reduced
        num_styles=2
    )

    # 4. Train with memory monitoring
    for epoch in range(5):
        print(f"Epoch {epoch+1}")
        for x, y in dataset:
            with tf.GradientTape() as tape:
                recon = model((x, y))
                loss = tf.reduce_mean(keras.losses.binary_crossentropy(x, recon))
                loss += sum(model.losses)

            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            print(f"Batch loss: {loss.numpy():.4f}")

        # Clear memory
        gc.collect()
        tf.keras.backend.clear_session()

## Style conditioned VAE Trainer

In [17]:
class StyleConditionedVAETrainer:
    def __init__(self, model, checkpoint_dir, optimizer=None):
        self.model = model
        self.checkpoint_dir = checkpoint_dir
        self.optimizer = optimizer or tf.keras.optimizers.Adam(1e-4)

        # Setup checkpoint manager
        self.checkpoint = tf.train.Checkpoint(
            model=self.model,
            optimizer=self.optimizer
        )
        self.manager = tf.train.CheckpointManager(
            self.checkpoint,
            checkpoint_dir,
            max_to_keep=3
        )

        # Restore if available
        if self.manager.latest_checkpoint:
            self.checkpoint.restore(self.manager.latest_checkpoint)
            print(f"Restored from {self.manager.latest_checkpoint}")
        else:
            print("Initializing from scratch")

        # Metrics
        self.train_loss = tf.keras.metrics.Mean(name='train_loss')
        self.reconstruction_loss = tf.keras.metrics.Mean(name='reconstruction_loss')

    @tf.function
    def train_step(self, x, style_ids):
        """Corrected training step - now instance method"""
        with tf.GradientTape() as tape:
            reconstruction = self.model((x, style_ids))
            reconstruction_loss = tf.reduce_mean(
                tf.keras.losses.binary_crossentropy(x, reconstruction)
            )
            total_loss = reconstruction_loss + tf.reduce_sum(self.model.losses)

        grads = tape.gradient(total_loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))

        self.train_loss.update_state(total_loss)
        self.reconstruction_loss.update_state(reconstruction_loss)

        return total_loss, reconstruction_loss

    def train(self, dataset, epochs, save_freq='epoch'):
        for epoch in range(epochs):
            print(f"\nEpoch {epoch+1}/{epochs}")

            # Reset metrics
            self.train_loss.reset_states()
            self.reconstruction_loss.reset_states()

            # Training loop
            for step, (x, style_ids) in enumerate(dataset):
                loss, rec_loss = self.train_step(x, style_ids)

                if step % 500 == 0:
                    print(f"Step {step}: Loss={loss:.4f}, Recon={rec_loss:.4f}")

            # End of epoch
            print(f"Epoch {epoch+1} - Loss: {self.train_loss.result():.4f}")

            if save_freq == 'epoch':
                save_path = self.manager.save()
                print(f"Saved checkpoint: {save_path}")

            # Memory management
            gc.collect()
            tf.keras.backend.clear_session()

##Create style-specific datasets

In [18]:
# Map style names to IDs
style_map = {'Jazz': 0, 'Pop': 1} # Tu peux ajouter 'Classical': 2 si nécessaire

# Fonction complète pour charger et prétraiter les données MIDI avec style IDs
def load_style_data(midi_folder_path, style_map, preprocessor=None, max_files=50):
    if preprocessor is None:
        preprocessor = MIDIPreprocessor()  # Create with default params if none provided
    datasets = {}
    style_ids = {}

    for style_name, style_id in style_map.items():
        style_path = os.path.join(midi_folder_path, style_name)
        if not os.path.exists(style_path):
            continue

        midi_files = [f for f in os.listdir(style_path) if f.endswith(('.mid','.midi'))][:max_files]
        all_sequences = []

        for midi_file in tqdm(midi_files, desc=f"Processing {style_name}"):
            try:
                file_path = os.path.join(style_path, midi_file)
                midi_data = preprocessor.load_midi_file(file_path)
                if midi_data is None:
                    continue

                piano_roll = preprocessor.midi_to_piano_roll(midi_data)
                sequences = preprocessor.extract_sequences(piano_roll)
                if len(sequences) > 0:
                    all_sequences.append(sequences)

            except Exception as e:
                print(f"\nSkipping {midi_file}: {str(e)}")
                continue

        if all_sequences:
            datasets[style_name] = np.concatenate(all_sequences)
            style_ids[style_name] = np.full(len(datasets[style_name]), style_id)

    return datasets, style_ids


def create_style_conditioned_dataset(datasets, style_ids, batch_size=32):
    """Simplified dataset creation without mixed precision"""
    all_data = []
    all_style_ids = []

    for style_name in datasets.keys():
        all_data.append(datasets[style_name])
        all_style_ids.append(style_ids[style_name])

    # Convert to numpy arrays
    all_data = np.concatenate(all_data, axis=0).astype(np.float32)  # Use float32
    all_style_ids = np.concatenate(all_style_ids, axis=0).astype(np.int32)

    # Create dataset
    dataset = tf.data.Dataset.from_tensor_slices((all_data, all_style_ids))
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    # Verify
    for batch in dataset.take(1):
        print("\nDataset sample:")
        print(f"Data shape: {batch[0].shape}, dtype: {batch[0].dtype}")
        print(f"Style shape: {batch[1].shape}, dtype: {batch[1].dtype}")

    return dataset

## Createand train the style conditioned model

In [19]:
def create_and_compile_model_style_conditioned(sequence_length=32,
                                            input_dim=128,
                                            latent_dim=256,
                                            intermediate_dim=512,
                                            num_styles=2):
    """Create and compile the style-conditioned MIDI-VAE model"""
    model = StyleConditionedMIDI_VAE(
        sequence_length,
        input_dim,
        latent_dim,
        intermediate_dim,
        num_styles
    )

    # Test model build
    try:
        dummy_input = (tf.zeros((1, sequence_length, input_dim), dtype=tf.float32),
                      tf.zeros((1,), dtype=tf.int32))
        _ = model(dummy_input)
        print("✓ Model successfully built with:")
        print(f"- Input shape: {dummy_input[0].shape}")
        print(f"- Style shape: {dummy_input[1].shape}")
    except Exception as e:
        print(f"❌ Model building failed: {str(e)}")
        raise

    optimizer = keras.optimizers.Adam(learning_rate=1e-3)
    return model, optimizer

##**Exécution complète du style conditioning**

In [20]:
  import gc

In [21]:

def run_style_conditioning():
    # 1. Define paths and style mapping
    midi_folder_path = "/content/drive/MyDrive/MIDI-VAE-NEW/data"
    style_map = {'Jazz': 0, 'Pop': 1}

    # 2. Load and preprocess data
    print("Loading and preprocessing data...")
    datasets, style_ids = load_style_data(midi_folder_path, style_map, preprocessor)

    # If no datasets found, use dummy data
    if not datasets:
        print("No datasets found. Using dummy data for testing...")
        sequence_length = 32
        input_dim = preprocessor.n_pitches if hasattr(preprocessor, 'n_pitches') else 88
        datasets = {
            'Jazz': np.random.rand(100, sequence_length, input_dim).astype(np.float32),
            'Pop': np.random.rand(100, sequence_length, input_dim).astype(np.float32)
        }
        style_ids = {
            'Jazz': np.full(100, 0),
            'Pop': np.full(100, 1)
        }

    # 3. Create TensorFlow dataset
    print("Creating TensorFlow dataset...")
    combined_dataset = create_style_conditioned_dataset(datasets, style_ids)

    # 4. Define model parameters
    first_batch = next(iter(combined_dataset))
    sequence_length = first_batch[0].shape[1]
    input_dim = first_batch[0].shape[2]
    latent_dim = 64 #256
    intermediate_dim = 128 # 512
    num_styles = len(style_map)

    print(f"\nModel parameters:")
    print(f"- Sequence length: {sequence_length}")
    print(f"- Input dimension: {input_dim}")
    print(f"- Latent dimension: {latent_dim}")
    print(f"- Number of styles: {num_styles}")

    # 5. Create and compile model
    print("\nCreating and compiling model...")
    model_style, optimizer_style = create_and_compile_model_style_conditioned(
        sequence_length=sequence_length,
        input_dim=input_dim,
        latent_dim=latent_dim,
        intermediate_dim=intermediate_dim,
        num_styles=num_styles
    )

    # 6. Setup training
    checkpoint_dir = "/content/drive/MyDrive/MIDI-VAE-NEW/checkpoints_style"
    os.makedirs(checkpoint_dir, exist_ok=True)

    print("\nStarting training...")
    trainer = StyleConditionedVAETrainer(model_style, checkpoint_dir, optimizer_style)
    trainer.train(combined_dataset, epochs=3,save_freq=1000 )

    return model_style, datasets

In [None]:
# Run this before calling run_style_conditioning()


# 1. Initialize

_, preprocessor = test_midi_preprocessing(midi_folder_path, check_subfolders=True)
# 2. Run with checks
try:
    model_style, style_datasets = run_style_conditioning()
except Exception as e:
    print(f"Error: {e}")
    print("Falling back to minimal working example...")
    # Create tiny valid model
    model_style = StyleConditionedMIDI_VAE(32, 88, 64, 128, 2)
    style_datasets = {}

=== MIDI Preprocessing Test ===

1. Checking MIDI files...
✓ Found 99 MIDI files

Sample files:
- Pop/Bobby_Vinton_-_Sealed_With_a_Kiss.mid
- Pop/Boyzone_-_Fathers_And_Sons.mid
- Pop/Bonnie_Tyler_-_Total_Eclipse_of_the_Heart.mid

2. Initializing preprocessor...

3. Testing file processing...
Testing file: Bobby_Vinton_-_Sealed_With_a_Kiss.mid
✓ Successfully processed:
- Duration: 146.51 sec
- Extracted 1141 sequences
- Sequence shape: (32, 89)
Loading and preprocessing data...


Processing Jazz:  76%|███████▌  | 38/50 [00:13<00:04,  2.95it/s]

Error loading /content/drive/MyDrive/MIDI-VAE-NEW/data/Jazz/a_taste_of_honey_jc2.mid: MIDI file has a largest tick of 11839910, it is likely corrupt


Processing Jazz: 100%|██████████| 50/50 [00:19<00:00,  2.62it/s]
Processing Pop: 100%|██████████| 49/49 [00:10<00:00,  4.64it/s]


Creating TensorFlow dataset...

Dataset sample:
Data shape: (32, 32, 89), dtype: <dtype: 'float32'>
Style shape: (32,), dtype: <dtype: 'int32'>

Model parameters:
- Sequence length: 32
- Input dimension: 89
- Latent dimension: 256
- Number of styles: 2

Creating and compiling model...
✓ Model successfully built with:
- Input shape: (1, 32, 89)
- Style shape: (1,)

Starting training...
Initializing from scratch

Epoch 1/3
Step 0: Loss=0.6999, Recon=0.6989
Step 100: Loss=0.1582, Recon=0.1581
Step 200: Loss=0.1581, Recon=0.1580
Step 300: Loss=0.1775, Recon=0.1775
Step 400: Loss=0.1630, Recon=0.1630
Step 500: Loss=0.1566, Recon=0.1566
Step 600: Loss=0.1670, Recon=0.1670
Step 700: Loss=0.2084, Recon=0.2084
Step 800: Loss=0.2062, Recon=0.2062
Step 900: Loss=0.2134, Recon=0.2134
Step 1000: Loss=0.2068, Recon=0.2067
Step 1100: Loss=0.1937, Recon=0.1937
Step 1200: Loss=0.1708, Recon=0.1706
Step 1300: Loss=0.2128, Recon=0.2089
Step 1400: Loss=0.1772, Recon=0.1757
Step 1500: Loss=0.1715, Recon=0.

##

#Generation de musique

In [None]:
def generate_music_from_model(model, style_id, output_path='generated_song.mid', temperature=1.0):
    import pretty_midi

    print(f"Generating music in style ID: {style_id}...")

    # 1. Générer une séquence
    generated = model.generate_with_style(style_id=style_id, num_samples=1, temperature=temperature)
    generated = generated.numpy()[0]  # retirer le batch

    # 2. Convertir le piano roll en PrettyMIDI
    pm = pretty_midi.PrettyMIDI()
    instrument = pretty_midi.Instrument(program=0)  # Piano

    fs = int(1 / 0.125)  # 8 temps par beat
    threshold = 0.3  # Binarisation du piano roll

    for pitch in range(generated.shape[1]):
        active = False
        note_on = 0
        for t in range(generated.shape[0]):
            if generated[t, pitch] > threshold and not active:
                active = True
                note_on = t
            elif generated[t, pitch] <= threshold and active:
                note_off = t
                start = note_on / fs
                end = note_off / fs
                note = pretty_midi.Note(
                    velocity=100,
                    pitch=pitch,
                    start=start,
                    end=end
                )
                instrument.notes.append(note)
                active = False

    pm.instruments.append(instrument)

    # 3. Sauvegarder en MIDI
    pm.write(output_path)
    print(f"Generated music saved at: {output_path}")


In [None]:
model_style, style_datasets = run_style_conditioning()

# Génère un morceau en style Jazz (ID = 0)
generate_music_from_model(model_style, style_id=0, output_path="jazz_generated.mid")

# Génère un morceau en style Pop (ID = 1)
generate_music_from_model(model_style, style_id=1, output_path="pop_generated.mid")
