In [None]:
import tensorflow as tf
import numpy as np
from magenta.models.music_vae import configs
from magenta.models.music_vae.trained_model import TrainedModel # We need this class
import os

# The pre-trained MusicVAE checkpoint is loaded through Magenta's TF1-style
# TrainedModel (graph + session), so TF2 behavior must be disabled before any
# graph is built.
tf.compat.v1.disable_v2_behavior()

# --- Configuration ---
# NOTE: LATENT_DIM is not hard-coded here; it is read from the model config
# (hparams.z_size) once the config is loaded, so it always matches the checkpoint.
MODEL_NAME = "mel_2bar_big"
CHECKPOINT_DIR = "models/download.magenta.tensorflow.org/models/music_vae/checkpoints"
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, f"{MODEL_NAME}.ckpt")
# ==============================================================================
# 1. SETUP & MODEL LOADING
# ==============================================================================

print("--- Step 1: Loading original TF1-style MusicVAE model ---")
mel_2bar_config = configs.CONFIG_MAP['cat-mel_2bar_big']
BASE_DIR = "models/download.magenta.tensorflow.org/models/music_vae"
# Built with os.path.join for consistency with CHECKPOINT_PATH above.
checkpoint_path = os.path.join(BASE_DIR, 'checkpoints', 'mel_2bar_big.ckpt')

# Model dimensions are read from the config wherever possible so they cannot
# drift from the checkpoint actually being loaded.
LATENT_DIM = mel_2bar_config.hparams.z_size
SEQUENCE_LENGTH = 32  # Define the desired generation length
# Use a batch size of 1 for easier comparison
BATCH_SIZE = 1
# Per-step one-hot depth of the decoder inputs (previously hard-coded as 90).
VOCAB_SIZE = mel_2bar_config.data_converter.input_depth
mel_2bar = TrainedModel(mel_2bar_config, batch_size=BATCH_SIZE, checkpoint_dir_or_path=checkpoint_path)
print("Original model loaded.")

graph = mel_2bar._sess.graph

# --- Use the exact names discovered from your debugging ---
# The z placeholder, shape (BATCH_SIZE, LATENT_DIM) -- (1, 512) per debugging.
Z_PLACEHOLDER_NAME = 'Placeholder_1:0'

# The output logits tensor from the sampling graph.
# NOTE(review): this name differs from the tensor actually fetched below
# ('decoder/TensorArrayStack_1/TensorArrayGatherV3:0'); confirm which is intended.
LOGITS_TENSOR_NAME = 'sample/decoder/rnn_output:0'

model_blueprint = mel_2bar_config.model
decoder_blueprint = model_blueprint.decoder

# Retrieve the relevant elements of the graph.
# The z lookup uses the constant above so the name is defined in one place.
temperature_placeholder = graph.get_tensor_by_name('Placeholder:0')
z_placeholder = graph.get_tensor_by_name(Z_PLACEHOLDER_NAME)
inputs_placeholder = graph.get_tensor_by_name('Placeholder_2:0')
controls_placeholder = graph.get_tensor_by_name('Placeholder_3:0')
inputs_length_placeholder = graph.get_tensor_by_name('Placeholder_4:0')
output_length_placeholder = graph.get_tensor_by_name('Placeholder_5:0')  # The final placeholder
logits_tensor = graph.get_tensor_by_name('decoder/TensorArrayStack_1/TensorArrayGatherV3:0')


# Fixed seed so the latent vector z is reproducible across runs.
np.random.seed(42)
z_np = np.random.randn(BATCH_SIZE, LATENT_DIM).astype(np.float32)

# Dummy inputs to satisfy the graph's requirements, based on your debugging
dummy_inputs = np.zeros((BATCH_SIZE, SEQUENCE_LENGTH, VOCAB_SIZE), dtype=np.float32)
dummy_inputs_length = np.array([SEQUENCE_LENGTH] * BATCH_SIZE, dtype=np.int32)
# Control depth is taken from the data converter (previously hard-coded as 0).
dummy_controls = np.zeros(
    (BATCH_SIZE, SEQUENCE_LENGTH, mel_2bar_config.data_converter.control_depth),
    dtype=np.float32,
)

# NOTE(review): the temperature is not a pure dummy. In Magenta's sampling
# decoders it appears to divide the logits before sampling, and each sampled
# step is fed back as the next decoder input. A value of 0 would therefore make
# every step after the first depend on inf/NaN sampling logits. 1.0 leaves the
# logits unscaled; confirm against magenta's lstm_models if near-greedy
# decoding (a small positive value) is wanted instead.
TEMPERATURE = 1.0

# Construct the full, correct feed dictionary
feed_dict = {
    temperature_placeholder: TEMPERATURE,
    z_placeholder: z_np,
    inputs_placeholder: dummy_inputs,
    inputs_length_placeholder: dummy_inputs_length,
    controls_placeholder: dummy_controls,
    output_length_placeholder: SEQUENCE_LENGTH,  # The final missing piece
}

logits_tf1 = mel_2bar._sess.run(
    logits_tensor,
    feed_dict=feed_dict
)

print("\nLogits shape from TF1 model:", logits_tf1.shape)

print("\nLogits for the very first step (first 5 values):")
print(logits_tf1[0, 0, :5])


# --- Main Comparison Logic ---

In [None]:

# Debugging helpers for inspecting the checkpoint and the loaded graph.
# For the TF1 graph (and its tensor names/values) to be valid,
# tf.compat.v1.disable_v2_behavior() must be called before the model is built.

def inspect_checkpoint(checkpoint_path):
    """
    Print every variable name and shape stored in a TensorFlow checkpoint.

    This is extremely useful for debugging name-related errors when mapping
    checkpoint variables onto a graph or a re-implemented model.

    Args:
        checkpoint_path: Path (or prefix / directory) of the checkpoint to read.

    Returns:
        The ``{variable_name: shape}`` map read from the checkpoint, or ``None``
        if the checkpoint could not be read.
    """
    print(f"\n--- Inspecting variables in checkpoint: {checkpoint_path} ---")
    shape_map = None
    try:
        reader = tf.train.load_checkpoint(checkpoint_path)
        shape_map = reader.get_variable_to_shape_map()
    except (ValueError, tf.errors.OpError) as e:
        # Only report genuine "checkpoint missing / unreadable" failures here.
        # Programming errors (e.g. NameError when `tf` was never imported) must
        # propagate instead of masquerading as a bad checkpoint.
        print(f"Could not read checkpoint: {e}")
    else:
        for key in sorted(shape_map):
            print(f"Tensor name: {key}, shape: {shape_map[key]}")
    print("--------------------------------------------------------\n")
    return shape_map

# Mirror of the configuration in the first cell: locate the checkpoint and
# dump its variable names/shapes.
CHECKPOINT_DIR = "models/download.magenta.tensorflow.org/models/music_vae/checkpoints"
MODEL_NAME = "mel_2bar_big"
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, "{}.ckpt".format(MODEL_NAME))

inspect_checkpoint(CHECKPOINT_PATH)


def get_tensor_names_from_graph(graph):
    """
    Print the name of every operation in a TensorFlow graph and each tensor it outputs.

    This helps identify the correct names to pass to ``graph.get_tensor_by_name``.

    Args:
        graph: A ``tf.Graph`` (any object exposing ``get_operations()`` whose
            ops have ``.name`` and ``.values()``).
    """
    print("\n--- Inspecting tensors in the graph ---")
    # Fetch the operation list once and iterate it directly. The previous
    # version called get_operations() again for every index, repeating the
    # list construction once per op (quadratic in graph size).
    for op in graph.get_operations():
        print(f"Operation name: {op.name}\n")
        for tensor in op.values():
            print(f"Tensor name: {tensor.name}, shape: {tensor.shape}\n")
    print("--------------------------------------------------------\n")


# Dump every op/tensor name of the loaded TF1 graph.
get_tensor_names_from_graph(graph=graph)


--- Inspecting variables in checkpoint: models/download.magenta.tensorflow.org/models/music_vae/checkpoints\mel_2bar_big.ckpt ---
Could not read checkpoint: name 'tf' is not defined
--------------------------------------------------------


--- Inspecting tensors in the graph ---
Operation name: global_step/Initializer/zeros

Tensor name: global_step/Initializer/zeros:0, shape: ()

Operation name: global_step

Tensor name: global_step:0, shape: ()

Operation name: global_step/IsInitialized/VarIsInitializedOp

Tensor name: global_step/IsInitialized/VarIsInitializedOp:0, shape: ()

Operation name: global_step/Assign

Operation name: global_step/Read/ReadVariableOp

Tensor name: global_step/Read/ReadVariableOp:0, shape: ()

Operation name: DropoutWrapperInit/Const

Tensor name: DropoutWrapperInit/Const:0, shape: ()

Operation name: DropoutWrapperInit/Const_1

Tensor name: DropoutWrapperInit/Const_1:0, shape: ()

Operation name: DropoutWrapperInit/Const_2

Tensor name: DropoutWrapperInit/