In [1]:
%reload_ext autoreload
%autoreload 2

import os

import tensorflow as tf

from model_def.decoder import MusicVAEDecoder, LATENT_DIM, OUTPUT_DEPTH, SEQUENCE_LENGTH, LSTM_UNITS, NUM_LAYERS












In [None]:



# Instantiate our TF2 Keras decoder, build every sub-layer with the input shape the
# Magenta checkpoint was trained with, then run one forward pass so the model is built.
# The weight loader in the next cell needs these variables to exist before it can assign into them.
decoder = MusicVAEDecoder(
    output_depth=OUTPUT_DEPTH,
    lstm_units=LSTM_UNITS,
    num_layers=NUM_LAYERS,
    sequence_length=SEQUENCE_LENGTH
)
print("Model instance created.")

# --- MANUALLY BUILD EACH LAYER WITH THE CORRECT INPUT SHAPE ---
print("\nManually building model layers...")

# 1. Build the initial dense layer. It takes `z` as input.
decoder.z_to_initial_state.build(input_shape=(None, LATENT_DIM))
print("Built 'z_to_initial_state' layer.")

# 2. Build the LSTM cells. This is the most critical part.
# The first cell's input is the previous step's output (OUTPUT_DEPTH) concatenated
# with z (LATENT_DIM): 90 + 512 = 602 for mel_2bar_big. That matches the checkpoint
# kernel 'decoder/multi_rnn_cell/cell_0/lstm_cell/kernel' of shape [2650, 8192],
# i.e. (602 input + 2048 recurrent) x (4 * 2048).
# Every later cell consumes the previous cell's output, so its width is carried forward.
first_lstm_input_dim = OUTPUT_DEPTH + LATENT_DIM
cell_input_dim = first_lstm_input_dim
for i, cell in enumerate(decoder.lstm_cells):
    cell.build(input_shape=(None, cell_input_dim))
    print(f"Built LSTM cell {i} with input dimension {cell_input_dim}.")
    cell_input_dim = cell.units

# 3. Build the final output projection layer. It takes the output of the last LSTM cell.
last_lstm_output_dim = decoder.lstm_cells[-1].units
decoder.output_projection.build(input_shape=(None, last_lstm_output_dim))
print("Built 'output_projection' layer.")

print("\nAll layers built successfully.")

# 4. Run one forward pass on a dummy latent so Keras marks the whole model as built.
# A named variable is used instead of `_`, which would shadow IPython's output history.
warmup_output = decoder(z=tf.zeros((1, LATENT_DIM)))




Model instance created.

Manually building model layers...
Built 'z_to_initial_state' layer.
Built LSTM cell 0 with input dimension 602.
Built LSTM cell 1 with input dimension 2048.
Built LSTM cell 2 with input dimension 2048.
Built 'output_projection' layer.

All layers built successfully.


In [3]:
# --- Now, loading the weights will work ---
# The previous cell created every decoder variable with the checkpoint's shapes,
# so the Magenta tensors can be copied onto them one layer at a time.
# NOTE(review): `load_magenta_weights` is not defined in any visible cell; import it
# explicitly so the notebook survives Restart Kernel -> Run All.
MODEL_NAME = "mel_2bar_big"
CHECKPOINT_DIR = "models/download.magenta.tensorflow.org/models/music_vae/checkpoints"
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, MODEL_NAME + ".ckpt")

print("\nLoading Magenta weights into the built model...")
load_magenta_weights(decoder, CHECKPOINT_PATH)
print("Weights loaded successfully.")



Loading Magenta weights into the built model...
Loaded weights for 'z_to_initial_state' layer.
Loaded weights for LSTM cell 0 from 'decoder/multi_rnn_cell/cell_0/lstm_cell/kernel'.
Loaded weights for LSTM cell 1 from 'decoder/multi_rnn_cell/cell_1/lstm_cell/kernel'.
Loaded weights for LSTM cell 2 from 'decoder/multi_rnn_cell/cell_2/lstm_cell/kernel'.
Loaded weights for 'output_projection' layer.

Successfully loaded all decoder weights from Magenta checkpoint!
Weights loaded successfully.


In [4]:
# Trace both tf.functions up front: each traced FuncGraph becomes a named serving signature.
# Presumably `reconstruct` / `generate` declare a fixed input_signature, since no
# arguments are passed here. Confirm in model_def.decoder.
concrete_reconstruct = decoder.reconstruct.get_concrete_function()
concrete_generate = decoder.generate.get_concrete_function()
print("Concrete functions created successfully.")

# --- Persist the full model as a SavedModel directory with both signatures ---
model_save_path = "models/music_vae_decoder_keras"
os.makedirs(model_save_path, exist_ok=True)

serving_signatures = {
    'reconstruct': concrete_reconstruct,
    'generate': concrete_generate,
}
decoder.save(model_save_path, signatures=serving_signatures, save_format="tf")
print(f"\nModel saved successfully to '{model_save_path}'.")


Concrete functions created successfully.
INFO:tensorflow:Assets written to: models/music_vae_decoder_keras\assets

Model saved successfully to 'models/music_vae_decoder_keras'.


--- Step 1: Loading original TF1-style MusicVAE model ---
INFO:tensorflow:Building MusicVAE model with BidirectionalLstmEncoder, CategoricalLstmDecoder, and hparams:
{'max_seq_len': 32, 'z_size': 512, 'free_bits': 0, 'max_beta': 0.5, 'beta_rate': 0.99999, 'batch_size': 1, 'grad_clip': 1.0, 'clip_mode': 'global_norm', 'grad_norm_clip_to_zero': 10000, 'learning_rate': 0.001, 'decay_rate': 0.9999, 'min_learning_rate': 1e-05, 'conditional': True, 'dec_rnn_size': [2048, 2048, 2048], 'enc_rnn_size': [2048], 'dropout_keep_prob': 1.0, 'sampling_schedule': 'inverse_sigmoid', 'sampling_rate': 1000, 'use_cudnn': False, 'residual_encoder': False, 'residual_decoder': False, 'control_preprocessing_rnn_size': [256]}
INFO:tensorflow:
Encoder Cells (bidirectional):
  units: [2048]

INFO:tensorflow:
Decoder Cells:
  units: [2048, 2048, 2048]

Instructions for updating:
Use `tf.cast` instead.


  tf.layers.dense(
  self._kernel = self.add_variable(
  self._bias = self.add_variable(


Instructions for updating:
Please use `keras.layers.Bidirectional(keras.layers.RNN(cell))`, which is equivalent to this API
Instructions for updating:
Please use `keras.layers.RNN(cell)`, which is equivalent to this API
Instructions for updating:
`scale_identity_multiplier` is deprecated; please combine it into `scale_diag` directly instead.


  mu = tf.layers.dense(
  sigma = tf.layers.dense(


INFO:tensorflow:Restoring parameters from models/download.magenta.tensorflow.org/models/music_vae/checkpoints/mel_2bar_big.ckpt
Original model loaded.


KeyError: "The name 'decoder/TensorArrayStack_1/TensorArrayGatherV3:0' refers to a Tensor which does not exist. The operation, 'decoder/TensorArrayStack_1/TensorArrayGatherV3', does not exist in the graph."

NameError: name 'decoder' is not defined

In [9]:
def inspect_checkpoint(checkpoint_path, skip_optimizer_slots=False):
    """Print every variable name and shape stored in a TensorFlow checkpoint.

    Extremely useful for debugging name/shape mismatches when mapping checkpoint
    variables onto a re-implemented model.

    Args:
        checkpoint_path: Checkpoint path/prefix accepted by ``tf.train.load_checkpoint``
            (e.g. ``.../mel_2bar_big.ckpt``).
        skip_optimizer_slots: If True, omit Adam optimizer state, meaning names ending
            in ``/Adam`` or ``/Adam_1`` plus ``beta1_power`` / ``beta2_power``.
            Those entries are not needed to load inference weights.

    Returns:
        dict mapping each printed variable name to its shape. The dict is empty if
        the checkpoint could not be read; the error is printed, not raised, because
        this is a best-effort debugging helper.
    """
    print(f"\n--- Inspecting variables in checkpoint: {checkpoint_path} ---")
    shapes = {}
    try:
        reader = tf.train.load_checkpoint(checkpoint_path)
        shape_map = reader.get_variable_to_shape_map()
    except Exception as e:  # best-effort: report and keep the notebook running
        print(f"Could not read checkpoint: {e}")
    else:
        for key in sorted(shape_map):
            is_optimizer_slot = (
                key.endswith(("/Adam", "/Adam_1"))
                or key in ("beta1_power", "beta2_power")
            )
            if skip_optimizer_slots and is_optimizer_slot:
                continue
            shapes[key] = shape_map[key]
            print(f"Tensor name: {key}, shape: {shape_map[key]}")
    print("--------------------------------------------------------\n")
    return shapes

# Locate the Magenta mel_2bar_big checkpoint and dump its variable names/shapes.
# NOTE(review): these constants duplicate the ones in the weight-loading cell;
# define them once in a top-level config cell instead.
MODEL_NAME = "mel_2bar_big"
CHECKPOINT_DIR = "models/download.magenta.tensorflow.org/models/music_vae/checkpoints"
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, MODEL_NAME + ".ckpt")
inspect_checkpoint(CHECKPOINT_PATH)


def get_tensor_names_from_graph(graph, name_filter=None):
    """Print every operation and its output tensors in a TensorFlow graph.

    Helps identify the exact names to use when accessing tensors, for example after
    a KeyError saying a tensor "does not exist in the graph".

    Args:
        graph: A ``tf.Graph``, or any object exposing ``get_operations()`` whose ops
            have ``.name`` and ``.values()``.
        name_filter: Optional substring. When given, only operations whose name
            contains it are printed and returned.

    Returns:
        list[str]: names of the printed tensors, in graph order.
    """
    print("\n--- Inspecting tensors in the graph ---")
    tensor_names = []
    # Fetch the op list once. The old code called get_operations() on every loop
    # iteration, and each call returns a new list, which made the loop quadratic.
    for op in graph.get_operations():
        if name_filter is not None and name_filter not in op.name:
            continue
        print(f"Operation name: {op.name}\n")
        for tensor in op.values():
            print(f"Tensor name: {tensor.name}, shape: {tensor.shape}\n")
            tensor_names.append(tensor.name)
    print("--------------------------------------------------------\n")
    return tensor_names


# Dump every op/tensor name in the current default graph.
# NOTE(review): under TF2 eager execution the default graph may hold few or no ops;
# if the goal is the TF1 MusicVAE graph, pass that model's graph explicitly.
default_graph = tf.compat.v1.get_default_graph()
get_tensor_names_from_graph(default_graph)






--- Inspecting variables in checkpoint: models/download.magenta.tensorflow.org/models/music_vae/checkpoints\mel_2bar_big.ckpt ---
Tensor name: beta1_power, shape: []
Tensor name: beta2_power, shape: []
Tensor name: decoder/multi_rnn_cell/cell_0/lstm_cell/bias, shape: [8192]
Tensor name: decoder/multi_rnn_cell/cell_0/lstm_cell/bias/Adam, shape: [8192]
Tensor name: decoder/multi_rnn_cell/cell_0/lstm_cell/bias/Adam_1, shape: [8192]
Tensor name: decoder/multi_rnn_cell/cell_0/lstm_cell/kernel, shape: [2650, 8192]
Tensor name: decoder/multi_rnn_cell/cell_0/lstm_cell/kernel/Adam, shape: [2650, 8192]
Tensor name: decoder/multi_rnn_cell/cell_0/lstm_cell/kernel/Adam_1, shape: [2650, 8192]
Tensor name: decoder/multi_rnn_cell/cell_1/lstm_cell/bias, shape: [8192]
Tensor name: decoder/multi_rnn_cell/cell_1/lstm_cell/bias/Adam, shape: [8192]
Tensor name: decoder/multi_rnn_cell/cell_1/lstm_cell/bias/Adam_1, shape: [8192]
Tensor name: decoder/multi_rnn_cell/cell_1/lstm_cell/kernel, shape: [4096, 8192]