In [None]:
import tensorflow as tf
from tensorflow.keras.layers import LSTMCell, RNN, Dense
import os
import requests

# TF2 / Keras re-implementation of the MusicVAE decoder (weights loaded from a Magenta TF1 checkpoint below).
class MusicVAEDecoder(tf.keras.Model):
    """The decoder portion of the MusicVAE model, re-implemented in TF2 Keras.

    Layer names mirror the variable names of Magenta's TF1 MusicVAE decoder so
    that ``load_magenta_weights`` can copy a pretrained checkpoint into it.

    Args:
        output_depth: Size of the per-step output vector produced by
            ``output_projection``.
        lstm_units: Number of units in every stacked LSTM layer.
        num_layers: Number of stacked LSTM layers.
        name: Keras model name.
    """

    def __init__(self, output_depth, lstm_units=2048, num_layers=2, name="decoder"):
        super(MusicVAEDecoder, self).__init__(name=name)
        # Projects z to the flattened initial state of every LSTM layer
        # (one c and one h vector per layer). Magenta's
        # `initial_cell_state_from_embedding` squashes this projection with
        # tanh, so the activation is needed for pretrained weights to match.
        self.z_to_initial_state = Dense(
            lstm_units * num_layers * 2, activation="tanh", name="z_to_initial_state"
        )
        self.lstm_cells = [LSTMCell(lstm_units, name=f"lstm_cell_{i}") for i in range(num_layers)]
        self.rnn = RNN(self.lstm_cells, return_sequences=True, return_state=True, name="decoder_rnn")
        self.output_projection = Dense(output_depth, name="output_projection")

    def call(self, z, sequence_length, inputs=None):
        """Run the decoder forward pass.

        Args:
            z: Latent vectors, shape (batch_size, latent_dim).
            sequence_length: Number of decoder steps to run when ``inputs`` is
                None. Ignored when ``inputs`` is given.
            inputs: Optional per-step decoder inputs of shape
                (batch_size, time, output_depth + latent_dim). When None, a
                default input is built from z (see below).

        Returns:
            Raw (pre-softmax) outputs of shape (batch_size, time, output_depth).
        """
        batch_size = tf.shape(z)[0]
        num_layers = len(self.lstm_cells)
        lstm_units = self.lstm_cells[0].units

        # Shape: (batch_size, num_layers * 2 * lstm_units), tanh-activated.
        initial_state_flat = self.z_to_initial_state(z)

        # Magenta flattens its MultiRNNCell state as [c0, h0, c1, h1, ...]
        # (a TF1 LSTMStateTuple is ordered (c, h)), so axis 2 below is (c, h).
        # Shape: (batch_size, num_layers, 2, lstm_units)
        initial_state_reshaped = tf.reshape(
            initial_state_flat, [batch_size, num_layers, 2, lstm_units]
        )

        # Keras LSTMCell expects each layer's state as [h, c], so swap the halves.
        initial_state = [
            [initial_state_reshaped[:, layer, 1], initial_state_reshaped[:, layer, 0]]
            for layer in range(num_layers)
        ]

        if inputs is None:
            # The first LSTM consumes [previous_output, z] at every step
            # (output_depth + latent_dim wide, e.g. 90 + 512 = 602 for
            # mel_2bar_big). The previous-output slot is the all-zero start
            # token, and z is tiled across time so it actually reaches the LSTM.
            # NOTE(review): this is not autoregressive. Magenta feeds each
            # sampled output back in as the next step's input. TODO: implement
            # step-wise sampling for faithful generation.
            output_depth = self.output_projection.units
            previous_output = tf.zeros(
                [batch_size, sequence_length, output_depth], dtype=z.dtype
            )
            tiled_z = tf.tile(tf.expand_dims(z, axis=1), [1, sequence_length, 1])
            inputs = tf.concat([previous_output, tiled_z], axis=-1)

        # With return_state=True the RNN returns [outputs, *states]; only the
        # per-step outputs are used.
        rnn_output, *_ = self.rnn(inputs, initial_state=initial_state)

        # Project the RNN output to the final output space.
        return self.output_projection(rnn_output)




def _tf1_lstm_to_keras(tf1_kernel, tf1_bias, units, forget_bias=1.0):
    """Convert one fused TF1 LSTM kernel/bias pair into Keras LSTMCell weights.

    TF1 LSTM cells store a single kernel of shape (input_dim + units, 4 * units)
    that multiplies ``concat([x, h])``. Its gate columns are ordered
    (input, candidate, forget, output), and ``forget_bias`` is added to the
    forget gate at run time rather than stored in the bias. Keras ``LSTMCell``
    splits the kernel into input and recurrent parts, orders gates
    (input, forget, candidate, output), and adds no implicit forget bias.

    Args:
        tf1_kernel: Fused TF1 kernel array.
        tf1_bias: TF1 bias array of shape (4 * units,).
        units: Number of LSTM units in the cell.
        forget_bias: Constant the TF1 cell added to the forget gate at run time.

    Returns:
        ``[kernel, recurrent_kernel, bias]`` suitable for ``LSTMCell.set_weights``.

    Raises:
        ValueError: If the array shapes are inconsistent with ``units``.
    """
    import numpy as np

    tf1_kernel = np.asarray(tf1_kernel)
    tf1_bias = np.asarray(tf1_bias)
    if (
        tf1_kernel.ndim != 2
        or tf1_kernel.shape[1] != 4 * units
        or tf1_bias.shape != (4 * units,)
    ):
        raise ValueError(
            f"Unexpected TF1 LSTM weight shapes kernel={tf1_kernel.shape}, "
            f"bias={tf1_bias.shape} for a cell with {units} units."
        )

    # Rows are [input rows; recurrent rows]. Deriving the split point from the
    # kernel avoids hard-coding each layer's input width.
    input_dim = tf1_kernel.shape[0] - units
    if input_dim <= 0:
        raise ValueError(
            f"TF1 kernel has {tf1_kernel.shape[0]} rows, too few for {units} units."
        )

    def to_keras_gate_order(arr):
        # (i, j, f, o) -> (i, f, j, o) along the last axis.
        i, j, f, o = np.split(arr, 4, axis=-1)
        return np.concatenate([i, f, j, o], axis=-1)

    kernel = to_keras_gate_order(tf1_kernel)
    bias = to_keras_gate_order(tf1_bias)  # fresh array, safe to modify in place
    # Fold TF1's run-time forget bias into the stored Keras bias (forget gate
    # occupies the second quarter in Keras order).
    bias[units:2 * units] += forget_bias
    return [kernel[:input_dim], kernel[input_dim:], bias]


def load_magenta_weights(decoder_model, checkpoint_path, forget_bias=1.0):
    """Load a Magenta TF1 MusicVAE decoder checkpoint into a TF2 Keras decoder.

    The model must already be built (called once) so its layers have variables
    for ``set_weights`` to fill.

    Args:
        decoder_model: Model exposing ``z_to_initial_state``, ``lstm_cells`` and
            ``output_projection`` layers (e.g. ``MusicVAEDecoder``).
        checkpoint_path: Path/prefix of the TF1 checkpoint.
        forget_bias: Forget-gate bias the TF1 LSTM cells added at run time.
            1.0 is TF1's default; pass 0.0 if the checkpoint's cells were built
            with ``forget_bias=0``.

    Raises:
        ValueError: If a checkpoint tensor's shape does not match the
            corresponding Keras layer.
    """
    reader = tf.train.load_checkpoint(checkpoint_path)

    # --- 1. Load z_to_initial_state weights ---
    z_kernel = reader.get_tensor("decoder/z_to_initial_state/kernel")
    z_bias = reader.get_tensor("decoder/z_to_initial_state/bias")
    decoder_model.z_to_initial_state.set_weights([z_kernel, z_bias])
    print("Loaded weights for 'z_to_initial_state' layer.")

    # --- 2. Load LSTM cell weights ---
    for i, cell in enumerate(decoder_model.lstm_cells):
        scope = f"decoder/multi_rnn_cell/cell_{i}/lstm_cell"
        tf1_kernel_name = f"{scope}/kernel"
        tf1_kernel = reader.get_tensor(tf1_kernel_name)
        tf1_bias = reader.get_tensor(f"{scope}/bias")

        # Gate order and forget bias differ between TF1 and Keras LSTMs, so
        # the fused TF1 weights must be converted, not just split.
        cell.set_weights(_tf1_lstm_to_keras(tf1_kernel, tf1_bias, cell.units, forget_bias))
        print(f"Loaded weights for LSTM cell {i} from '{tf1_kernel_name}'.")

    # --- 3. Load output_projection weights ---
    out_kernel = reader.get_tensor("decoder/output_projection/kernel")
    out_bias = reader.get_tensor("decoder/output_projection/bias")
    decoder_model.output_projection.set_weights([out_kernel, out_bias])
    print("Loaded weights for 'output_projection' layer.")

    print("\nSuccessfully loaded all decoder weights from Magenta checkpoint!")




# --- Example Usage ---

# --- Configuration ---
# Assumes the Magenta checkpoint has already been downloaded and extracted into
# CHECKPOINT_DIR; this cell does not download it.
MODEL_NAME = "mel_2bar_big"
CHECKPOINT_DIR = "models/download.magenta.tensorflow.org/models/music_vae/checkpoints"
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, f"{MODEL_NAME}.ckpt")

# --- Hyperparameters for 'mel_2bar_big' ---
# These must match the checkpoint: load_magenta_weights copies tensors into
# layers whose shapes are fixed by these values.
LATENT_DIM = 512
OUTPUT_DEPTH = 90
SEQUENCE_LENGTH = 32  # fixed number of decoder steps for this model
LSTM_UNITS = 2048
NUM_LAYERS = 3

# --- Main execution ---

# 1. Instantiate the TF2 Keras decoder with the checkpoint's layer count.
decoder = MusicVAEDecoder(
    output_depth=OUTPUT_DEPTH,
    lstm_units=LSTM_UNITS,
    num_layers=NUM_LAYERS,
)

# 2. Call the model once so Keras creates its variables; set_weights in the
#    loader needs built layers.
print("\nBuilding Keras model to initialize variables...")
dummy_z = tf.zeros([1, LATENT_DIM])
decoder(dummy_z, sequence_length=SEQUENCE_LENGTH)
print("Model built.")

# 3. Copy the pretrained weights from the Magenta checkpoint.
load_magenta_weights(decoder, CHECKPOINT_PATH)

# 4. Smoke-test inference with a random latent vector.
print("\nRunning a test inference with correct parameters...")
batch_size = 1
z = tf.random.normal([batch_size, LATENT_DIM])

# Decode a sequence of the model's fixed length.
generated_sequence = decoder(z, sequence_length=SEQUENCE_LENGTH)

print("Shape of the generated sequence:", generated_sequence.shape)
print("Inference successful!")




In [None]:
def inspect_checkpoint(checkpoint_path):
    """Print every variable stored in a checkpoint alongside its shape.

    Useful when mapping checkpoint variable names onto a model's layers.
    Failures while reading are reported on stdout instead of being raised,
    so this is safe to call from an exploratory notebook cell.

    Args:
        checkpoint_path: Path/prefix of the checkpoint to inspect.
    """
    header = f"\n--- Inspecting variables in checkpoint: {checkpoint_path} ---"
    print(header)
    try:
        ckpt_reader = tf.train.load_checkpoint(checkpoint_path)
        name_to_shape = ckpt_reader.get_variable_to_shape_map()
        # Variable names are unique, so sorting the items orders by name only.
        for var_name, var_shape in sorted(name_to_shape.items()):
            print(f"Tensor name: {var_name}, shape: {var_shape}")
    except Exception as err:  # best-effort debug helper: report, don't raise
        print(f"Could not read checkpoint: {err}")
    print("--------------------------------------------------------\n")

# Checkpoint location, redefined here so this notebook cell does not depend on
# the earlier one having been run.
MODEL_NAME = "mel_2bar_big"
CHECKPOINT_DIR = "models/download.magenta.tensorflow.org/models/music_vae/checkpoints"
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, "{}.ckpt".format(MODEL_NAME))

# Dump every variable name/shape in the checkpoint.
inspect_checkpoint(CHECKPOINT_PATH)



