In [57]:
%reload_ext autoreload
%autoreload 2

from model_def.decoder import MusicVAEDecoder









In [58]:
import tensorflow as tf


# 1. Instantiate the TF2 Keras decoder model (constructor defaults).
decoder = MusicVAEDecoder()
print("Model instance created.")

# Read the configuration once; every shape used below is derived from it.
# Expected keys: "latent_dim", "output_depth", "sequence_length".
config = decoder.get_config()

# --- MANUALLY BUILD EACH LAYER WITH THE CORRECT INPUT SHAPE ---
print("\nManually building model layers...")

# 2. Build the initial dense layer. It takes `z` as input.
decoder.z_to_initial_state.build(input_shape=(None, config["latent_dim"]))
print("Built 'z_to_initial_state' layer.")

# 3. Build the LSTM cells. This is the most critical part.
# The input to the *first* LSTM cell is the concatenation of the previous
# step's output and z, so its width is output_depth + latent_dim
# (90 + 512 = 602 for the configuration used here).
first_lstm_input_dim = config["output_depth"] + config["latent_dim"]
decoder.lstm_cells[0].build(input_shape=(None, first_lstm_input_dim))
print(f"Built LSTM cell 0 with input dimension {first_lstm_input_dim}.")

# The input to each subsequent LSTM cell is the output of the previous cell,
# so its input width equals the previous cell's `units`.
for i in range(1, len(decoder.lstm_cells)):
    prev_cell_output_dim = decoder.lstm_cells[i-1].units
    decoder.lstm_cells[i].build(input_shape=(None, prev_cell_output_dim))
    print(f"Built LSTM cell {i} with input dimension {prev_cell_output_dim}.")

# 4. Build the final output projection layer. It takes the output of the last LSTM cell.
last_lstm_output_dim = decoder.lstm_cells[-1].units
decoder.output_projection.build(input_shape=(None, last_lstm_output_dim))
print("Built 'output_projection' layer.")

print("\nAll layers built successfully.")

batch_size = 1 # Dummy batch size for building

# Create all-zero dummy tensors with the shapes the model is called with:
#   z        -> (batch, latent_dim)
#   sequence -> (batch, sequence_length, output_depth)
dummy_z = tf.zeros((batch_size, config["latent_dim"]))
dummy_sequence = tf.zeros((batch_size, config["sequence_length"], config["output_depth"]))

# 5. Finally, build the model itself by calling it once, passing the inputs as a tuple.
# The result is discarded; only the side effect of building matters here.
_ = decoder((dummy_z, dummy_sequence))






Model instance created.

Manually building model layers...
Built 'z_to_initial_state' layer.
Built LSTM cell 0 with input dimension 602.
Built LSTM cell 1 with input dimension 2048.
Built LSTM cell 2 with input dimension 2048.
Built 'output_projection' layer.

All layers built successfully.


In [59]:
from model_def.decoder import load_magenta_weights
import os

# --- Restore the pretrained Magenta weights into the already-built decoder ---

# Location of the checkpoint: <CHECKPOINT_DIR>/<MODEL_NAME>.ckpt
MODEL_NAME = "mel_2bar_big"
CHECKPOINT_DIR = "models/download.magenta.tensorflow.org/models/music_vae/checkpoints"
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, f"{MODEL_NAME}.ckpt")

print("\nLoading Magenta weights into the built model...")

# `decoder` must come from the earlier build step; the helper is given the
# model instance and the checkpoint path.
load_magenta_weights(decoder, CHECKPOINT_PATH)
print("Weights loaded successfully.")



Loading Magenta weights into the built model...
Loaded weights for 'z_to_initial_state' layer.
Loaded weights for LSTM cell 0 from 'decoder/multi_rnn_cell/cell_0/lstm_cell/kernel'.
Loaded weights for LSTM cell 1 from 'decoder/multi_rnn_cell/cell_1/lstm_cell/kernel'.
Loaded weights for LSTM cell 2 from 'decoder/multi_rnn_cell/cell_2/lstm_cell/kernel'.
Loaded weights for 'output_projection' layer.

Successfully loaded all decoder weights from Magenta checkpoint!
Weights loaded successfully.


In [60]:
# Export the weight-loaded decoder together with two named signatures.
print("\nCreating concrete functions for model signatures...")

# Obtain one concrete function per exported entry point. Both calls pass no
# arguments -- presumably `reconstruct` and `generate` declare their own
# input signature in model_def.decoder; TODO confirm.
concrete_reconstruct = decoder.reconstruct.get_concrete_function()
concrete_generate = decoder.generate.get_concrete_function()

print("Concrete functions created successfully.")

# Target directory for the saved model; created up front if it is missing.
# NOTE: relies on `os` having been imported by an earlier cell.
model_save_path = "models/music_vae_decoder_keras"
os.makedirs(model_save_path, exist_ok=True)

# Save in TensorFlow format, exposing the two concrete functions by name.
decoder.save(
    model_save_path,
    signatures={
        "reconstruct": concrete_reconstruct,
        "generate": concrete_generate,
    },
    save_format="tf",
)
print(f"\nModel saved successfully to '{model_save_path}'.")



Creating concrete functions for model signatures...
Concrete functions created successfully.




INFO:tensorflow:Assets written to: models/music_vae_decoder_keras\assets


INFO:tensorflow:Assets written to: models/music_vae_decoder_keras\assets



Model saved successfully to 'models/music_vae_decoder_keras'.
