In [31]:
import tensorflow as tf
from tensorflow.keras.layers import LSTMCell, RNN, Dense
import os
import requests

LATENT_DIM = 512
OUTPUT_DEPTH = 90
SEQUENCE_LENGTH = 32 # The correct, fixed sequence length for this model
BATCH_SIZE = 1

# The MusicVAEDecoder class remains the same
class MusicVAEDecoder(tf.keras.Model):
    """The decoder portion of the MusicVAE model."""
    def __init__(self, output_depth, lstm_units=2048, num_layers=2, name="decoder"):
        super(MusicVAEDecoder, self).__init__(name=name)
        self.z_to_initial_state = Dense(lstm_units * num_layers * 2, name="z_to_initial_state")
        self.lstm_cells = [LSTMCell(lstm_units, name=f"lstm_cell_{i}") for i in range(num_layers)]
        self.rnn = RNN(self.lstm_cells, return_sequences=True, return_state=True, name="decoder_rnn")
        self.output_projection = Dense(output_depth, name="output_projection")
        self.vocab_size = output_depth

    def call(self, z, sequence_length, inputs=None):
        """
        Performs the forward pass of the decoder.
        """
        batch_size = tf.shape(z)[0]
        num_layers = len(self.lstm_cells)
        lstm_units = self.lstm_cells[0].units

        # Project the latent vector 'z' to get the initial state for the LSTM.
        # Shape: (batch_size, num_layers * 2 * lstm_units)
        initial_state_flat = self.z_to_initial_state(z)

        # Reshape to separate layers and the h/c states.
        # Shape: (batch_size, num_layers, 2, lstm_units)
        initial_state_reshaped = tf.reshape(
            initial_state_flat, [batch_size, num_layers, 2, lstm_units]
        )

        # Transpose to group h/c states by layer.
        # Shape: (num_layers, 2, batch_size, lstm_units)
        initial_state_transposed = tf.transpose(initial_state_reshaped, [1, 2, 0, 3])

        # Unstack to create the final list of states for each layer.
        # This creates a list of `num_layers` elements.
        # Each element is a tensor of shape (2, batch_size, lstm_units).
        initial_state_list = tf.unstack(initial_state_transposed)

        # Further unstack each layer's state into (h, c) tuples.
        # The final structure is: [ (h0, c0), (h1, c1), ... ]
        # which is what the Keras RNN layer expects.
        initial_state = [tf.unstack(s) for s in initial_state_list]


         # --- THE CRUCIAL FIX ---
        # 2. Prepare the latent vector for concatenation at each time step.
        # We need to repeat `z` so it can be attached to every element of the sequence.
        # Tile z from shape [batch, latent_dim] to [batch, sequence_length, latent_dim]
        z_repeated = tf.tile(tf.expand_dims(z, 1), [1, sequence_length, 1])

        

        # This distinction here is very subtle and important. If inputs is provided,
        # we run the RNN in "training" mode (teacher forcing). If inputs is None,
        # we run in "inference" mode (autoregressive generation).
        if inputs is not None:
            # The input to the first LSTM is z 
            # which has a dimension of 512 
            input_depth = 90 
            inputs = tf.zeros([batch_size, sequence_length, input_depth])
        
        # Concatenate z with the inputs along the feature dimension.
        # `inputs` has shape [batch, sequence_length, 90]
            # `z_repeated` has shape [batch, sequence_length, 512]
            # The result will have shape [batch, sequence_length, 602]
            rnn_inputs = tf.concat([inputs, z_repeated], axis=-1)

             # Run the RNN.
            rnn_output, *_ = self.rnn(rnn_inputs, initial_state=initial_state)

            # Project the RNN output to the final output space.
            output = self.output_projection(rnn_output)
            return output
        else:
            # Start with a "start token", which is a zero vector.
            # Shape: [batch_size, vocab_size]
            step_input = tf.zeros([batch_size, self.vocab_size])
            
            # The state will be updated at each step of the loop
            current_state = initial_state
            
            # List to collect the output logits at each step
            all_logits = []

            for _ in range(sequence_length):
                # A. Concatenate the current step's input with z
                # Shape: [batch_size, 90 + 512] -> [batch_size, 602]
                step_input_with_z = tf.concat([step_input, z], axis=-1)

                # B. Run the stacked LSTM cells for a single step
                # We must manually call each cell in the stack.
                # The input to the first cell is our concatenated vector.
                # The input to subsequent cells is the output of the previous cell.
                cell_input = step_input_with_z
                new_states = []
                for i, cell in enumerate(self.lstm_cells):
                    # `cell` returns (output, [new_h, new_c])
                    # `current_state[i]` is the (h, c) tuple for this cell
                    cell_output, (new_h, new_c) = cell(cell_input, states=current_state[i])
                    new_states.append([new_h, new_c])
                    # The output of this cell becomes the input for the next
                    cell_input = cell_output
                
                # The final output of the stack is the last cell's output
                final_cell_output = cell_output
                current_state = new_states # Update the state for the next iteration

                # C. Project the output to get the logits for this single step
                # Shape: [batch_size, 90]
                step_logits = self.output_projection(final_cell_output)
                all_logits.append(step_logits)

                # D. Prepare the input for the *next* step (autoregression)
                # Get the most likely note index from the logits
                next_token_indices = tf.argmax(step_logits, axis=-1)
                # Convert the index back to a one-hot vector
                step_input = tf.one_hot(next_token_indices, depth=self.vocab_size)

            # Stack all the single-step logits into a final tensor
            # Shape: [batch_size, sequence_length, 90]
            final_logits = tf.stack(all_logits, axis=1)
            return final_logits


    # --- 2. The "Teaching" Endpoint: Decorated for Training/Reconstruction ---
    # This will be one of the functions available in your saved model.
    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None, LATENT_DIM], name="z"),
        tf.TensorSpec(shape=[], dtype=tf.int32, name="sequence_length"),
        tf.TensorSpec(shape=[None, None, VOCAB_SIZE], name="inputs")
    ])
    def reconstruct(self, z, sequence_length, inputs):
        """Runs the model in teacher-forcing mode."""
        return self.call(z, sequence_length, inputs=inputs)

    # --- 3. The "Improvising" Endpoint: Decorated for Generation ---
    # This will be another function available in your saved model.
    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None, LATENT_DIM], name="z"),
        tf.TensorSpec(shape=[], dtype=tf.int32, name="sequence_length")
    ])
    def generate(self, z, sequence_length):
        """Runs the model in autoregressive generation mode."""
        return self.call(z, sequence_length, inputs=None)
   
def load_magenta_weights(decoder_model, checkpoint_path):
    """
    Loads weights from a TF1 Magenta checkpoint into a TF2 Keras decoder model.
    (Final corrected version)
    """
    reader = tf.train.load_checkpoint(checkpoint_path)

    # --- 1. Load z_to_initial_state weights ---
    z_kernel = reader.get_tensor("decoder/z_to_initial_state/kernel")
    z_bias = reader.get_tensor("decoder/z_to_initial_state/bias")
    decoder_model.z_to_initial_state.set_weights([z_kernel, z_bias])
    print("Loaded weights for 'z_to_initial_state' layer.")

    # --- 2. Load LSTM cell weights ---
    for i, cell in enumerate(decoder_model.lstm_cells):
        tf1_kernel_name = f"decoder/multi_rnn_cell/cell_{i}/lstm_cell/kernel"
        tf1_bias_name = f"decoder/multi_rnn_cell/cell_{i}/lstm_cell/bias"
        
        tf1_kernel = reader.get_tensor(tf1_kernel_name)
        tf1_bias = reader.get_tensor(tf1_bias_name)

        # THE FIX: Use the correct input dimension for splitting the kernel,
        # based on the layer index.
        if i == 0:
            # The original model's first layer has a complex input of dim 602.
            input_dim = 602
        else:
            # Subsequent layers take the output of the previous LSTM layer.
            input_dim = cell.units # which is 2048

        # Perform the split at the correct index.
        keras_kernel = tf1_kernel[:input_dim, :]
        keras_recurrent_kernel = tf1_kernel[input_dim:, :]
        
        # Now the shapes will match perfectly.
        cell.set_weights([keras_kernel, keras_recurrent_kernel, tf1_bias])
        print(f"Loaded weights for LSTM cell {i} from '{tf1_kernel_name}'.")

    # --- 3. Load output_projection weights ---
    out_kernel = reader.get_tensor("decoder/output_projection/kernel")
    out_bias = reader.get_tensor("decoder/output_projection/bias")
    decoder_model.output_projection.set_weights([out_kernel, out_bias])
    print("Loaded weights for 'output_projection' layer.")

    print("\nSuccessfully loaded all decoder weights from Magenta checkpoint!")




# --- Example Usage ---

# --- Configuration ---
MODEL_NAME = "mel_2bar_big"
CHECKPOINT_DIR = "models/download.magenta.tensorflow.org/models/music_vae/checkpoints"
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, f"{MODEL_NAME}.ckpt")

# --- Correct parameters for 'mel_2bar_big' ---









In [None]:


# --- Define constants ---
LATENT_DIM = 512
OUTPUT_DEPTH = 90 # This is your OUTPUT_DEPTH
LSTM_UNITS = 2048
NUM_LAYERS = 3

# 2. Instantiate our TF2 Keras decoder model with the correct number of layers.
decoder = MusicVAEDecoder(
    output_depth=OUTPUT_DEPTH,
    lstm_units=LSTM_UNITS,
    num_layers=NUM_LAYERS
)
print("Model instance created.")

# --- MANUALLY BUILD EACH LAYER WITH THE CORRECT INPUT SHAPE ---
print("\nManually building model layers...")

# 1. Build the initial dense layer. It takes `z` as input.
decoder.z_to_initial_state.build(input_shape=(None, LATENT_DIM))
print(f"Built 'z_to_initial_state' layer.")

# 2. Build the LSTM cells. This is the most critical part.
# The input to the *first* LSTM cell is the concatenation of the previous step's output (90) and z (512).
first_lstm_input_dim = OUTPUT_DEPTH + LATENT_DIM # 90 + 512 = 602
decoder.lstm_cells[0].build(input_shape=(None, first_lstm_input_dim))
print(f"Built LSTM cell 0 with input dimension {first_lstm_input_dim}.")

# The input to subsequent LSTM cells is the output of the previous cell.
for i in range(1, len(decoder.lstm_cells)):
    prev_cell_output_dim = decoder.lstm_cells[i-1].units
    decoder.lstm_cells[i].build(input_shape=(None, prev_cell_output_dim))
    print(f"Built LSTM cell {i} with input dimension {prev_cell_output_dim}.")

# 3. Build the final output projection layer. It takes the output of the last LSTM cell.
last_lstm_output_dim = decoder.lstm_cells[-1].units
decoder.output_projection.build(input_shape=(None, last_lstm_output_dim))
print(f"Built 'output_projection' layer.")

print("\nAll layers built successfully.")
# 4. Finally, indicate that the model is built.
decoder.built=True

# --- Now, loading the weights will work ---
print("\nLoading Magenta weights into the built model...")
load_magenta_weights(decoder, CHECKPOINT_PATH) # This should now succeed
print("Weights loaded successfully.")

# This forces the creation of the `FuncGraph(name=reconstruct)`
concrete_reconstruct = decoder.reconstruct.get_concrete_function()

# This forces the creation of the `FuncGraph(name=generate)`
concrete_generate = decoder.generate.get_concrete_function()

print("Concrete functions created successfully.")

# --- Now, you can save the full model with signatures ---
# Define the path for the saved model directory
model_save_path = "models/music_vae_decoder_keras" 
os.makedirs(model_save_path, exist_ok=True)
decoder.save(model_save_path, signatures={
    'reconstruct': concrete_reconstruct,
    'generate': concrete_generate
},save_format="tf")
print(f"\nModel saved successfully to '{model_save_path}'.")


Model instance created.

Manually building model layers...
Built 'z_to_initial_state' layer.
Built LSTM cell 0 with input dimension 602.
Built LSTM cell 1 with input dimension 2048.
Built LSTM cell 2 with input dimension 2048.
Built 'output_projection' layer.

All layers built successfully.

Loading Magenta weights into the built model...
Loaded weights for 'z_to_initial_state' layer.
Loaded weights for LSTM cell 0 from 'decoder/multi_rnn_cell/cell_0/lstm_cell/kernel'.
Loaded weights for LSTM cell 1 from 'decoder/multi_rnn_cell/cell_1/lstm_cell/kernel'.
Loaded weights for LSTM cell 2 from 'decoder/multi_rnn_cell/cell_2/lstm_cell/kernel'.
Loaded weights for 'output_projection' layer.

Successfully loaded all decoder weights from Magenta checkpoint!
Weights loaded successfully.


InaccessibleTensorError: in user code:

    File "C:\Users\thoma\AppData\Local\Temp\ipykernel_3540\1994342386.py", line 154, in generate  *
        return self.call(z, sequence_length, inputs=None)
    File "C:\Users\thoma\AppData\Local\Temp\ipykernel_3540\3144523106.py", line 131, in call  *
        final_logits = tf.stack(all_logits, axis=1)

    InaccessibleTensorError: <tf.Tensor 'output_projection/BiasAdd:0' shape=(?, 90) dtype=float32> is out of scope and cannot be used here. Use return values, explicit Python locals or TensorFlow collections to access it.
    Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.
    
    <tf.Tensor 'output_projection/BiasAdd:0' shape=(?, 90) dtype=float32> was defined here:
        File "D:\Program Files\Python310\lib\runpy.py", line 196, in _run_module_as_main
          return _run_code(code, main_globals, None,
        File "D:\Program Files\Python310\lib\runpy.py", line 86, in _run_code
          exec(code, run_globals)
        File "d:\Users\thoma\Documents\git\python_midi_training\.venv\lib\site-packages\ipykernel_launcher.py", line 18, in <module>
          app.launch_new_instance()
        File "d:\Users\thoma\Documents\git\python_midi_training\.venv\lib\site-packages\traitlets\config\application.py", line 1075, in launch_instance
          app.start()
        File "d:\Users\thoma\Documents\git\python_midi_training\.venv\lib\site-packages\ipykernel\kernelapp.py", line 758, in start
          self.io_loop.start()
        File "d:\Users\thoma\Documents\git\python_midi_training\.venv\lib\site-packages\tornado\platform\asyncio.py", line 211, in start
          self.asyncio_loop.run_forever()
        File "D:\Program Files\Python310\lib\asyncio\base_events.py", line 595, in run_forever
          self._run_once()
        File "D:\Program Files\Python310\lib\asyncio\base_events.py", line 1881, in _run_once
          handle._run()
        File "D:\Program Files\Python310\lib\asyncio\events.py", line 80, in _run
          self._context.run(self._callback, *self._args)
        File "d:\Users\thoma\Documents\git\python_midi_training\.venv\lib\site-packages\ipykernel\utils.py", line 71, in preserve_context
          return await f(*args, **kwargs)
        File "d:\Users\thoma\Documents\git\python_midi_training\.venv\lib\site-packages\ipykernel\kernelbase.py", line 614, in shell_main
          await self.dispatch_shell(msg, subshell_id=subshell_id)
        File "d:\Users\thoma\Documents\git\python_midi_training\.venv\lib\site-packages\ipykernel\kernelbase.py", line 471, in dispatch_shell
          await result
        File "d:\Users\thoma\Documents\git\python_midi_training\.venv\lib\site-packages\ipykernel\ipkernel.py", line 366, in execute_request
          await super().execute_request(stream, ident, parent)
        File "d:\Users\thoma\Documents\git\python_midi_training\.venv\lib\site-packages\ipykernel\kernelbase.py", line 827, in execute_request
          reply_content = await reply_content
        File "d:\Users\thoma\Documents\git\python_midi_training\.venv\lib\site-packages\ipykernel\ipkernel.py", line 458, in do_execute
          res = shell.run_cell(
        File "d:\Users\thoma\Documents\git\python_midi_training\.venv\lib\site-packages\ipykernel\zmqshell.py", line 663, in run_cell
          return super().run_cell(*args, **kwargs)
        File "d:\Users\thoma\Documents\git\python_midi_training\.venv\lib\site-packages\IPython\core\interactiveshell.py", line 3077, in run_cell
          result = self._run_cell(
        File "d:\Users\thoma\Documents\git\python_midi_training\.venv\lib\site-packages\IPython\core\interactiveshell.py", line 3132, in _run_cell
          result = runner(coro)
        File "d:\Users\thoma\Documents\git\python_midi_training\.venv\lib\site-packages\IPython\core\async_helpers.py", line 128, in _pseudo_sync_runner
          coro.send(None)
        File "d:\Users\thoma\Documents\git\python_midi_training\.venv\lib\site-packages\IPython\core\interactiveshell.py", line 3336, in run_cell_async
          has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
        File "d:\Users\thoma\Documents\git\python_midi_training\.venv\lib\site-packages\IPython\core\interactiveshell.py", line 3519, in run_ast_nodes
          if await self.run_code(code, result, async_=asy):
        File "d:\Users\thoma\Documents\git\python_midi_training\.venv\lib\site-packages\IPython\core\interactiveshell.py", line 3579, in run_code
          exec(code_obj, self.user_global_ns, self.user_ns)
        File "C:\Users\thoma\AppData\Local\Temp\ipykernel_3540\1239954717.py", line 54, in <module>
          decoder.save(model_save_path, signatures={
        File "d:\Users\thoma\Documents\git\python_midi_training\.venv\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
          return fn(*args, **kwargs)
        File "d:\Users\thoma\Documents\git\python_midi_training\.venv\lib\site-packages\keras\engine\training.py", line 2826, in save
          saving_api.save_model(
        File "d:\Users\thoma\Documents\git\python_midi_training\.venv\lib\site-packages\keras\saving\saving_api.py", line 145, in save_model
          return legacy_sm_saving_lib.save_model(
        File "d:\Users\thoma\Documents\git\python_midi_training\.venv\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
          return fn(*args, **kwargs)
        File "d:\Users\thoma\Documents\git\python_midi_training\.venv\lib\site-packages\keras\saving\legacy\save.py", line 168, in save_model
          saved_model_save.save(
        File "d:\Users\thoma\Documents\git\python_midi_training\.venv\lib\site-packages\keras\saving\legacy\saved_model\save.py", line 98, in save
          saved_nodes, node_paths = save_lib.save_and_return_nodes(
        File "C:\Users\thoma\AppData\Local\Temp\ipykernel_3540\1994342386.py", line 154, in generate
          return self.call(z, sequence_length, inputs=None)
        File "C:\Users\thoma\AppData\Local\Temp\ipykernel_3540\3144523106.py", line 66, in call
          if inputs is not None:
        File "C:\Users\thoma\AppData\Local\Temp\ipykernel_3540\3144523106.py", line 95, in call
          for _ in range(sequence_length):
        File "C:\Users\thoma\AppData\Local\Temp\ipykernel_3540\3144523106.py", line 120, in call
          step_logits = self.output_projection(final_cell_output)
        File "d:\Users\thoma\Documents\git\python_midi_training\.venv\lib\site-packages\keras\engine\base_layer_v1.py", line 838, in __call__
          outputs = call_fn(cast_inputs, *args, **kwargs)
        File "d:\Users\thoma\Documents\git\python_midi_training\.venv\lib\site-packages\keras\layers\core\dense.py", line 252, in call
          outputs = tf.nn.bias_add(outputs, self.bias)
    
    The tensor <tf.Tensor 'output_projection/BiasAdd:0' shape=(?, 90) dtype=float32> cannot be accessed from FuncGraph(name=generate, id=1272522258656), because it was defined in FuncGraph(name=tmp, id=1272522417984), which is out of scope.


In [None]:
import tensorflow as tf
import numpy as np
import os
import requests

LATENT_DIM = 512

# IMPORTANT: We must use TF1 compatibility mode to load and run the original Magenta model.
# This needs to be at the very top of your script.
import tensorflow.compat.v1 as tf1
tf1.disable_v2_behavior()


from magenta.models.music_vae import configs
from magenta.models.music_vae.trained_model import TrainedModel # We need this class

# Use tensorflow.compat.v1 and disable V2 behavior for the original model




# ==============================================================================
# 1. SETUP & MODEL LOADING
# ==============================================================================

print("--- Step 1: Loading original TF1-style MusicVAE model ---")
mel_2bar_config = configs.CONFIG_MAP['cat-mel_2bar_big']
BASE_DIR = "models/download.magenta.tensorflow.org/models/music_vae"
checkpoint_path = BASE_DIR + '/checkpoints/mel_2bar_big.ckpt'

# Use a batch size of 1 for easier comparison
LATENT_DIM = mel_2bar_config.hparams.z_size
SEQUENCE_LENGTH = 32 # Define the desired generation length
BATCH_SIZE = 1
VOCAB_SIZE = 90
mel_2bar = TrainedModel(mel_2bar_config, batch_size=BATCH_SIZE, checkpoint_dir_or_path=checkpoint_path)
print("Original model loaded.")

graph = mel_2bar._sess.graph

# --- Use the exact names discovered from your debugging ---
# The z placeholder with shape (1, 512)
Z_PLACEHOLDER_NAME = 'Placeholder_1:0'

# The output logits tensor from the sampling graph
LOGITS_TENSOR_NAME = 'sample/decoder/rnn_output:0'



model_blueprint = mel_2bar_config.model
decoder_blueprint = model_blueprint.decoder

    # Retrive the relevant elements of the graph
temperature_placeholder = graph.get_tensor_by_name('Placeholder:0')
z_placeholder = graph.get_tensor_by_name('Placeholder_1:0')
inputs_placeholder = graph.get_tensor_by_name('Placeholder_2:0')
controls_placeholder = graph.get_tensor_by_name('Placeholder_3:0')
inputs_length_placeholder = graph.get_tensor_by_name('Placeholder_4:0')
output_length_placeholder = graph.get_tensor_by_name('Placeholder_5:0') # The final placeholder
logits_tensor = graph.get_tensor_by_name('decoder/TensorArrayStack_1/TensorArrayGatherV3:0')


np.random.seed(42)
z_np = np.random.randn(BATCH_SIZE, LATENT_DIM).astype(np.float32)

# Dummy inputs to satisfy the graph's requirements, based on your debugging
dummy_inputs = np.zeros((BATCH_SIZE, SEQUENCE_LENGTH, VOCAB_SIZE), dtype=np.float32)
dummy_inputs_length = np.array([SEQUENCE_LENGTH] * BATCH_SIZE, dtype=np.int32)
dummy_controls = np.zeros((BATCH_SIZE, SEQUENCE_LENGTH, 0), dtype=np.float32)

# Construct the full, correct feed dictionary
feed_dict = {
    temperature_placeholder: 0, # We don't need the temperature here other than as a dummy
    z_placeholder: z_np,
    inputs_placeholder: dummy_inputs,
    inputs_length_placeholder: dummy_inputs_length,
    controls_placeholder: dummy_controls,
    output_length_placeholder: SEQUENCE_LENGTH # The final missing piece
}



logits_tf1 = mel_2bar._sess.run(
    logits_tensor,
    feed_dict=feed_dict
)

print("\nLogits for the very first step (first 5 values):")
print(logits_tf1[2, 0, :5])




# --- Main Comparison Logic ---






  




--- Step 1: Loading original TF1-style MusicVAE model ---
INFO:tensorflow:Building MusicVAE model with BidirectionalLstmEncoder, CategoricalLstmDecoder, and hparams:
{'max_seq_len': 32, 'z_size': 512, 'free_bits': 0, 'max_beta': 0.5, 'beta_rate': 0.99999, 'batch_size': 1, 'grad_clip': 1.0, 'clip_mode': 'global_norm', 'grad_norm_clip_to_zero': 10000, 'learning_rate': 0.001, 'decay_rate': 0.9999, 'min_learning_rate': 1e-05, 'conditional': True, 'dec_rnn_size': [2048, 2048, 2048], 'enc_rnn_size': [2048], 'dropout_keep_prob': 1.0, 'sampling_schedule': 'inverse_sigmoid', 'sampling_rate': 1000, 'use_cudnn': False, 'residual_encoder': False, 'residual_decoder': False, 'control_preprocessing_rnn_size': [256]}
INFO:tensorflow:
Encoder Cells (bidirectional):
  units: [2048]

INFO:tensorflow:
Decoder Cells:
  units: [2048, 2048, 2048]

INFO:tensorflow:Restoring parameters from models/download.magenta.tensorflow.org/models/music_vae/checkpoints/mel_2bar_big.ckpt
Original model loaded.

Logits for 

In [12]:
generated_sequence = decoder(z_np, sequence_length=SEQUENCE_LENGTH)
print(generated_sequence[0, 0, :5][:10])

Tensor("strided_slice_2:0", shape=(5,), dtype=float32)


In [None]:
def inspect_checkpoint(checkpoint_path):
    """
    A helper function to print all variable names and their shapes in a checkpoint.
    This is extremely useful for debugging name-related errors.
    """
    print(f"\n--- Inspecting variables in checkpoint: {checkpoint_path} ---")
    try:
        reader = tf.train.load_checkpoint(checkpoint_path)
        shape_map = reader.get_variable_to_shape_map()
        for key in sorted(shape_map.keys()):
            print(f"Tensor name: {key}, shape: {shape_map[key]}")
    except Exception as e:
        print(f"Could not read checkpoint: {e}")
    print("--------------------------------------------------------\n")

MODEL_NAME = "mel_2bar_big"
CHECKPOINT_DIR = "models/download.magenta.tensorflow.org/models/music_vae/checkpoints"
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, f"{MODEL_NAME}.ckpt")

inspect_checkpoint(CHECKPOINT_PATH)


def get_tensor_names_from_graph(graph):
    """
    A helper function to print all tensor names in a TensorFlow graph.
    This helps identify the correct names to use when accessing tensors.
    """
    print("\n--- Inspecting tensors in the graph ---")
    for index in range(len(graph.get_operations())):
        op = graph.get_operations()[index]
        print(f"Operation name: {op.name}"+"\n")
        for tensor in op.values():
            print(f"Tensor name: {tensor.name}, shape: {tensor.shape}"+"\n")
    print("--------------------------------------------------------\n")






