In [1]:
import tensorflow as tf


# --- IMPORTANT: adjust these so CHECKPOINT_PATH points at your actual checkpoint ---
BASE_DIR = "models/download.magenta.tensorflow.org/models/music_vae"
# The checkpoint location is derived from BASE_DIR.
CHECKPOINT_PATH = f"{BASE_DIR}/checkpoints/mel_2bar_big.ckpt"




import numpy as np

def debug_first_lstm_cell(decoder_model, z, first_step_input):
    """
    Re-compute the first LSTM cell's first timestep by hand and compare the
    resulting states with the ones the Keras cell itself returns.

    The manual pass uses the weights returned by the cell's own
    ``get_weights()`` and assumes the gates are laid out as [i, f, c, o] with
    sigmoid/tanh activations.

    NOTE(review): if ``lstm_cells`` are stock Keras LSTMCell layers, both sides
    of the comparison share the same weights and the same equations, so a
    match may not by itself prove the weights were converted correctly from
    the source checkpoint -- confirm against a reference output.

    Args:
        decoder_model: Your loaded Keras decoder model. Must expose
            ``_get_initial_state(z)`` and an indexable ``lstm_cells``.
        z: The latent vector, shape [1, latent_dim].
        first_step_input: The first element of the teacher sequence,
            shape [1, 1, output_depth].

    Returns:
        None. The distances and the verdict are printed.
    """
    print("\n--- Debugging First LSTM Cell (Manual vs. Keras) ---")

    # 1. Initial (h, c) state of the first layer, as produced by the model.
    # This implicitly exercises the model's _get_initial_state logic.
    initial_state_list = decoder_model._get_initial_state(z)
    h_0, c_0 = initial_state_list[0]

    # 2. Single-step input for the first cell: concat([input_step_0, z]).
    # Dropping the length-1 time axis first gives [batch, features] directly,
    # so z needs no expand/tile round trip.
    rnn_input_step = tf.concat(
        [tf.squeeze(first_step_input, axis=1), z], axis=-1
    )

    # 3. Reference states from the Keras cell. The cell's first return value
    # (the step output) is not needed for the comparison.
    first_cell = decoder_model.lstm_cells[0]
    _, (keras_h_1, keras_c_1) = first_cell(rnn_input_step, states=[h_0, c_0])

    # 4. Same calculation done MANUALLY with the cell's weights.
    weights = first_cell.get_weights()
    if len(weights) != 3:
        print("ERROR: First LSTM cell does not have 3 weights (kernel, recurrent_kernel, bias).")
        return

    kernel, recurrent_kernel, bias = weights

    # Manual LSTM math: gates = (x @ W) + (h @ U) + b
    gates = (
        tf.matmul(rnn_input_step, kernel)
        + tf.matmul(h_0, recurrent_kernel)
        + bias
    )

    # Split into i, f, c, o gates (assuming Keras [i, f, c, o] order)
    i, f, c_tilde, o = tf.split(gates, 4, axis=-1)

    # Apply activations
    i = tf.sigmoid(i)
    f = tf.sigmoid(f)
    c_tilde = tf.tanh(c_tilde)
    o = tf.sigmoid(o)

    # New cell state and hidden state
    manual_c_1 = f * c_0 + i * c_tilde
    manual_h_1 = o * tf.tanh(manual_c_1)

    # 5. Compare via mean squared difference, pulled out as plain scalars so
    # the threshold test below works on numbers rather than tensors.
    h_dist = tf.reduce_mean(tf.square(keras_h_1 - manual_h_1)).numpy()
    c_dist = tf.reduce_mean(tf.square(keras_c_1 - manual_c_1)).numpy()

    print(f"Distance on hidden state (h_1): {h_dist:.10f}")
    print(f"Distance on cell state (c_1):   {c_dist:.10f}")

    if h_dist < 1e-9 and c_dist < 1e-9:
        print("\nSUCCESS: The first LSTM cell's forward pass matches the manual calculation.")
        print("This means the weights for cell_0 are loaded correctly and the problem is likely in a subsequent cell (cell_1) or the final output_projection.")
    else:
        print("\nFAILURE: Mismatch in the first LSTM cell's calculation.")
        print("This strongly suggests the weight loading logic for the LSTM (gate order, forget bias, or kernel split) is still incorrect.")



%reload_ext autoreload
%autoreload 2

import tensorflow as tf
import numpy as np
from model_def.decoder import MusicVAEDecoder, AutoregressiveStep

# --- Configuration ---
MODEL_DIR = "models/music_vae_decoder_keras"
LATENT_DIM = 512
SEQUENCE_LENGTH = 32  # The sequence length you used when saving the model

# --- 1. Load the Keras SavedModel ---
print(f"Loading Keras model from: {MODEL_DIR}")

# Custom classes handed to the loader so it can rebuild the saved model.
custom_objects = dict(
    MusicVAEDecoder=MusicVAEDecoder,
    AutoregressiveStep=AutoregressiveStep,
)

# Restore the whole model: architecture plus the traced signatures.
loaded_keras_model = tf.keras.models.load_model(
    MODEL_DIR,
    custom_objects=custom_objects,
)
print("Keras model loaded successfully.")

# --- 2. Access the 'generate' signature ---
# `signatures` maps names to the pre-traced functions stored with the model.
generate_signature = loaded_keras_model.signatures["generate"]
print("Located 'generate' signature.")
# Element [1] of structured_input_signature holds the keyword-argument specs.
print("Signature inputs:", list(generate_signature.structured_input_signature[1]))
print("Signature outputs:", list(generate_signature.structured_outputs))


# Build fixed, non-zero test inputs for the decoder loaded above.
#
# Stateless RNG ops are used on purpose: an op-level `seed=` on
# tf.random.normal / tf.random.uniform only fixes the *sequence* of draws, so
# re-running this cell in the same kernel would yield different tensors each
# time. Stateless ops return the same values for the same seed pair on every
# call. The two inputs get different seed pairs so they are not drawn from
# the same stream.
z_sample = tf.random.stateless_normal(
    [1, loaded_keras_model.latent_dim],
    seed=[42, 0],
)
teacher_input_sample = tf.random.stateless_uniform(
    [1, loaded_keras_model.sequence_length, loaded_keras_model.output_depth],
    seed=[42, 1],
)

# Extract the very first step of the teacher input (time axis kept, length 1)
first_step = teacher_input_sample[:, :1, :]

# Run the debugging function
debug_first_lstm_cell(loaded_keras_model, z_sample, first_step)





Loading Keras model from: models/music_vae_decoder_keras
Keras model loaded successfully.
Located 'generate' signature.
Signature inputs: ['z']
Signature outputs: ['output_0']

--- Debugging First LSTM Cell (Manual vs. Keras) ---
Distance on hidden state (h_1): 0.0000000000
Distance on cell state (c_1):   0.0000000000

SUCCESS: The first LSTM cell's forward pass matches the manual calculation.
This means the weights for cell_0 are loaded correctly and the problem is likely in a subsequent cell (cell_1) or the final output_projection.


In [2]:
import tensorflow as tf
import numpy as np

def debug_second_lstm_cell(decoder_model, z, first_step_input):
    """
    Re-compute the SECOND LSTM cell's first timestep by hand and compare the
    resulting states with the ones the Keras cell itself returns.

    The first cell is not re-checked here: its Keras output is taken as-is and
    fed to the second cell. The manual pass uses the weights returned by the
    second cell's own ``get_weights()`` and assumes the gates are laid out as
    [i, f, c, o] with sigmoid/tanh activations.

    NOTE(review): if ``lstm_cells`` are stock Keras LSTMCell layers, both sides
    of the comparison share the same weights and the same equations, so a
    match may not by itself prove the weights were converted correctly from
    the source checkpoint -- confirm against a reference output.

    Args:
        decoder_model: Your loaded Keras decoder model. Must expose
            ``_get_initial_state(z)`` and an indexable ``lstm_cells``.
        z: The latent vector, shape [1, latent_dim].
        first_step_input: The first element of the teacher sequence,
            shape [1, 1, output_depth].

    Returns:
        None. The distances and the verdict are printed.
    """
    print("\n--- Debugging Second LSTM Cell (Manual vs. Keras) ---")

    # 1. Initial states for the first two cells.
    initial_state_list = decoder_model._get_initial_state(z)
    if len(initial_state_list) < 2:
        print("ERROR: Model does not have at least two LSTM layers.")
        return
    h0_cell0, c0_cell0 = initial_state_list[0]
    h0_cell1, c0_cell1 = initial_state_list[1]

    # 2. Output of the first cell, which is the input to the second.
    # Dropping the length-1 time axis first gives [batch, features] directly,
    # so z needs no expand/tile round trip.
    first_cell = decoder_model.lstm_cells[0]
    rnn_input_step_cell0 = tf.concat(
        [tf.squeeze(first_step_input, axis=1), z], axis=-1
    )
    input_for_cell1, _ = first_cell(rnn_input_step_cell0, states=[h0_cell0, c0_cell0])

    # 3. Reference states from the Keras cell for cell_1. The cell's first
    # return value (the step output) is not needed for the comparison.
    second_cell = decoder_model.lstm_cells[1]
    _, (keras_h1_cell1, keras_c1_cell1) = second_cell(
        input_for_cell1, states=[h0_cell1, c0_cell1]
    )

    # 4. Same calculation done MANUALLY for cell_1.
    weights = second_cell.get_weights()
    if len(weights) != 3:
        print("ERROR: Second LSTM cell does not have 3 weights (kernel, recurrent_kernel, bias).")
        return

    kernel, recurrent_kernel, bias = weights

    # Manual LSTM math: gates = (x @ W) + (h @ U) + b
    gates = (
        tf.matmul(input_for_cell1, kernel)
        + tf.matmul(h0_cell1, recurrent_kernel)
        + bias
    )

    # Split into i, f, c, o gates (assuming Keras [i, f, c, o] order)
    i, f, c_tilde, o = tf.split(gates, 4, axis=-1)

    # Apply activations
    i, f, o = tf.sigmoid(i), tf.sigmoid(f), tf.sigmoid(o)
    c_tilde = tf.tanh(c_tilde)

    # New cell state and hidden state
    manual_c1_cell1 = f * c0_cell1 + i * c_tilde
    manual_h1_cell1 = o * tf.tanh(manual_c1_cell1)

    # 5. Compare via mean squared difference, pulled out as plain scalars.
    h_dist = tf.reduce_mean(tf.square(keras_h1_cell1 - manual_h1_cell1)).numpy()
    c_dist = tf.reduce_mean(tf.square(keras_c1_cell1 - manual_c1_cell1)).numpy()

    print(f"Distance on hidden state (h_1) for cell_1: {h_dist:.10f}")
    print(f"Distance on cell state (c_1) for cell_1:   {c_dist:.10f}")

    if h_dist < 1e-9 and c_dist < 1e-9:
        print("\nSUCCESS: The second LSTM cell's forward pass matches the manual calculation.")
        print("This means weights for cell_1 are also correct. The problem is likely in the final cell or the output_projection layer.")
    else:
        print("\nFAILURE: Mismatch in the second LSTM cell's calculation.")
        print("This suggests the weight loading logic for cell_1 is incorrect, specifically the `input_dim` used for the kernel split.")



# --- Check the second LSTM cell ---

# Reuses `loaded_keras_model`, `z_sample` and `first_step` from the earlier cell.

# The first-cell check (debug_first_lstm_cell) was already run in the earlier
# cell, so it is not repeated here.

# Now, run the new debug function for the second cell
debug_second_lstm_cell(loaded_keras_model, z_sample, first_step)



--- Debugging Second LSTM Cell (Manual vs. Keras) ---
Distance on hidden state (h_1) for cell_1: 0.0000000000
Distance on cell state (c_1) for cell_1:   0.0000000000

SUCCESS: The second LSTM cell's forward pass matches the manual calculation.
This means weights for cell_1 are also correct. The problem is likely in the final cell or the output_projection layer.
