In [1]:
%reload_ext autoreload
%autoreload 2

import tensorflow as tf
import numpy as np
from model_def.decoder import MusicVAEDecoder, AutoregressiveStep

# --- Configuration ---
MODEL_DIR = "models/music_vae_decoder_keras"
LATENT_DIM = 512
SEQUENCE_LENGTH = 32 # The sequence length you used when saving the model

# --- 1. Load the Keras SavedModel ---
print(f"Loading Keras model from: {MODEL_DIR}")
# Create the dictionary of custom objects
custom_objects = {
    "MusicVAEDecoder": MusicVAEDecoder,
    "AutoregressiveStep": AutoregressiveStep
}


# This loads the entire model, including the architecture and the traced signatures.
loaded_keras_model = tf.keras.models.load_model(MODEL_DIR,custom_objects=custom_objects)
print("Keras model loaded successfully.")

# --- 2. Access the 'generate' signature ---
# The signatures dictionary holds the pre-traced functions we saved.
generate_signature = loaded_keras_model.signatures['generate']
print("Located 'generate' signature.")
print("Signature inputs:", list(generate_signature.structured_input_signature[1].keys()))
print("Signature outputs:", list(generate_signature.structured_outputs.keys()))









  




Loading Keras model from: models/music_vae_decoder_keras
Keras model loaded successfully.
Located 'generate' signature.
Signature inputs: ['z']
Signature outputs: ['output_0']


In [2]:
# --- 3. Create a random latent vector ---
# We'll use a fixed seed to ensure this input is reproducible.
np.random.seed(42)
z_input = np.random.randn(1, LATENT_DIM).astype(np.float32)
print(f"\nGenerated random latent vector 'z' with shape: {z_input.shape}")



Generated random latent vector 'z' with shape: (1, 512)


In [5]:
# --- 4. Run inference ---
print("Running inference with the Keras model...")
# The signature expects a TensorFlow constant.
# The output is a dictionary, as defined by the signature's outputs.
keras_output_dict = generate_signature(z=tf.constant(z_input))

# The output key might be 'output_0' or a more descriptive name.
# We inspect the dictionary keys to find the correct one.
output_key = list(keras_output_dict.keys())[0]
keras_logits = keras_output_dict[output_key].numpy()

print("Inference complete.")
print(f"Keras model output logits shape: {keras_logits.shape}")

print(keras_logits[0, 0, :5])

#musicvae sequence [ 4.8953643  2.3632274 -6.4150443 -8.422484  -7.8997383]



Running inference with the Keras model...
Inference complete.
Keras model output logits shape: (1, 32, 90)
[-0.21730028 -3.5196676  -1.7660146  -2.4989333  -2.6764984 ]


In [6]:
import tensorflow as tf
import numpy as np
from magenta.models.music_vae import configs
from magenta.models.music_vae.trained_model import TrainedModel # We need this class
import os

tf.compat.v1.disable_v2_behavior()

LATENT_DIM = 512

# --- Example Usage ---

# --- Configuration ---
MODEL_NAME = "mel_2bar_big"
CHECKPOINT_DIR = "models/download.magenta.tensorflow.org/models/music_vae/checkpoints"
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, f"{MODEL_NAME}.ckpt")

# --- Correct parameters for 'mel_2bar_big' ---

# Use tensorflow.compat.v1 and disable V2 behavior for the original model

# ==============================================================================
# 1. SETUP & MODEL LOADING
# ==============================================================================

print("--- Step 1: Loading original TF1-style MusicVAE model ---")
mel_2bar_config = configs.CONFIG_MAP['cat-mel_2bar_big']
BASE_DIR = "models/download.magenta.tensorflow.org/models/music_vae"
checkpoint_path = BASE_DIR + '/checkpoints/mel_2bar_big.ckpt'

# Use a batch size of 1 for easier comparison
LATENT_DIM = mel_2bar_config.hparams.z_size
SEQUENCE_LENGTH = 32 # Define the desired generation length
BATCH_SIZE = 1
VOCAB_SIZE = 90
mel_2bar = TrainedModel(mel_2bar_config, batch_size=BATCH_SIZE, checkpoint_dir_or_path=checkpoint_path)
print("Original model loaded.")

graph = mel_2bar._sess.graph

# --- Use the exact names discovered from your debugging ---
# The z placeholder with shape (1, 512)
Z_PLACEHOLDER_NAME = 'Placeholder_1:0'

# The output logits tensor from the sampling graph
LOGITS_TENSOR_NAME = 'sample/decoder/rnn_output:0'



model_blueprint = mel_2bar_config.model
decoder_blueprint = model_blueprint.decoder

    # Retrive the relevant elements of the graph
temperature_placeholder = graph.get_tensor_by_name('Placeholder:0')
z_placeholder = graph.get_tensor_by_name('Placeholder_1:0')
inputs_placeholder = graph.get_tensor_by_name('Placeholder_2:0')
controls_placeholder = graph.get_tensor_by_name('Placeholder_3:0')
inputs_length_placeholder = graph.get_tensor_by_name('Placeholder_4:0')
output_length_placeholder = graph.get_tensor_by_name('Placeholder_5:0') # The final placeholder
logits_tensor = graph.get_tensor_by_name('decoder/TensorArrayStack_1/TensorArrayGatherV3:0')




# Dummy inputs to satisfy the graph's requirements, based on your debugging
dummy_inputs = np.zeros((BATCH_SIZE, SEQUENCE_LENGTH, VOCAB_SIZE), dtype=np.float32)
dummy_inputs_length = np.array([SEQUENCE_LENGTH] * BATCH_SIZE, dtype=np.int32)
dummy_controls = np.zeros((BATCH_SIZE, SEQUENCE_LENGTH, 0), dtype=np.float32)

# Construct the full, correct feed dictionary
feed_dict = {
    temperature_placeholder: 0, # We don't need the temperature here other than as a dummy
    z_placeholder: z_input,
    inputs_placeholder: dummy_inputs,
    inputs_length_placeholder: dummy_inputs_length,
    controls_placeholder: dummy_controls,
    output_length_placeholder: SEQUENCE_LENGTH # The final missing piece
}



logits_tf1 = mel_2bar._sess.run(
    logits_tensor,
    feed_dict=feed_dict
)

print("\nLogits shape from TF1 model:", logits_tf1.shape)

print("\nLogits for the very first step (first 5 values):")
print(logits_tf1[0, 0, :5])













# --- Main Comparison Logic ---

Instructions for updating:
non-resource variables are not supported in the long term
--- Step 1: Loading original TF1-style MusicVAE model ---
INFO:tensorflow:Building MusicVAE model with BidirectionalLstmEncoder, CategoricalLstmDecoder, and hparams:
{'max_seq_len': 32, 'z_size': 512, 'free_bits': 0, 'max_beta': 0.5, 'beta_rate': 0.99999, 'batch_size': 1, 'grad_clip': 1.0, 'clip_mode': 'global_norm', 'grad_norm_clip_to_zero': 10000, 'learning_rate': 0.001, 'decay_rate': 0.9999, 'min_learning_rate': 1e-05, 'conditional': True, 'dec_rnn_size': [2048, 2048, 2048], 'enc_rnn_size': [2048], 'dropout_keep_prob': 1.0, 'sampling_schedule': 'inverse_sigmoid', 'sampling_rate': 1000, 'use_cudnn': False, 'residual_encoder': False, 'residual_decoder': False, 'control_preprocessing_rnn_size': [256]}
INFO:tensorflow:
Encoder Cells (bidirectional):
  units: [2048]

INFO:tensorflow:
Decoder Cells:
  units: [2048, 2048, 2048]

Instructions for updating:
Use `tf.cast` instead.


  tf.layers.dense(
  self._kernel = self.add_variable(
  self._bias = self.add_variable(


Instructions for updating:
Please use `keras.layers.Bidirectional(keras.layers.RNN(cell))`, which is equivalent to this API
Instructions for updating:
Please use `keras.layers.RNN(cell)`, which is equivalent to this API
Instructions for updating:
`scale_identity_multiplier` is deprecated; please combine it into `scale_diag` directly instead.
INFO:tensorflow:Restoring parameters from models/download.magenta.tensorflow.org/models/music_vae/checkpoints/mel_2bar_big.ckpt


  mu = tf.layers.dense(
  sigma = tf.layers.dense(


Original model loaded.

Logits shape from TF1 model: (32, 1, 90)

Logits for the very first step (first 5 values):
[ 4.8953643  2.3632274 -6.4150443 -8.422484  -7.8997383]


In [7]:
musicvae_embeddings= logits_tf1[0,0,:]
keras_embedding= keras_logits[0,0,:]

dist_musicvae_vs_keras = np.linalg.norm(musicvae_embeddings - keras_embedding)
print(dist_musicvae_vs_keras)

44.072735


In [None]:

# This is for debugging purposes. 
# If you want valid values for tf1, don't forget to include tf.compat.v1.disable_v2_behavior()

def inspect_checkpoint(checkpoint_path):
    """
    A helper function to print all variable names and their shapes in a checkpoint.
    This is extremely useful for debugging name-related errors.
    """
    print(f"\n--- Inspecting variables in checkpoint: {checkpoint_path} ---")
    try:
        reader = tf.train.load_checkpoint(checkpoint_path)
        shape_map = reader.get_variable_to_shape_map()
        for key in sorted(shape_map.keys()):
            print(f"Tensor name: {key}, shape: {shape_map[key]}")
    except Exception as e:
        print(f"Could not read checkpoint: {e}")
    print("--------------------------------------------------------\n")

MODEL_NAME = "mel_2bar_big"
CHECKPOINT_DIR = "models/download.magenta.tensorflow.org/models/music_vae/checkpoints"
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, f"{MODEL_NAME}.ckpt")

inspect_checkpoint(CHECKPOINT_PATH)


def get_tensor_names_from_graph(graph):
    """
    A helper function to print all tensor names in a TensorFlow graph.
    This helps identify the correct names to use when accessing tensors.
    """
    print("\n--- Inspecting tensors in the graph ---")
    for index in range(len(graph.get_operations())):
        op = graph.get_operations()[index]
        print(f"Operation name: {op.name}"+"\n")
        for tensor in op.values():
            print(f"Tensor name: {tensor.name}, shape: {tensor.shape}"+"\n")
    print("--------------------------------------------------------\n")


get_tensor_names_from_graph(graph)