In [13]:

import os
import requests
import tarfile
import io

# --- Step 1: Download the Pre-trained MusicVAE Model (streamed) ---
#
# Streams the bundled checkpoint tarball over HTTP and unpacks it under
# MODEL_DIR without buffering the whole archive in memory. The download is
# skipped when the checkpoint's `.index` file is already on disk.

# Checkpoint name as it appears in the download URL. Note that this is
# 'mel_2bar_big', which is not the same string as the 'cat-mel_2bar_big' key
# used to look up the model config.
MODEL_NAME = 'mel_2bar_big'
# NOTE(review): plain HTTP and no checksum -- the archive's integrity is not verified.
MODEL_URL = 'http://download.magenta.tensorflow.org/models/music_vae/checkpoints_bundled/mel_2bar_big.ckpt.tar'

MODEL_DIR = 'models'  # Directory the archive is extracted into
# Sub-path (relative to MODEL_DIR) where the checkpoint files are expected after extraction.
CHECKPOINT_SUBDIR = os.path.join('download.magenta.tensorflow.org', 'models', 'music_vae', 'checkpoints')
# Base name shared by the extracted checkpoint files.
CHECKPOINT_BASE_NAME = os.path.join(MODEL_DIR, CHECKPOINT_SUBDIR, MODEL_NAME + '.ckpt')
# The `.index` file is used as the marker that extraction succeeded.
CHECKPOINT_FILE_TO_CHECK = CHECKPOINT_BASE_NAME + '.index'

# (connect, read) timeouts in seconds, so a stalled server cannot hang the cell forever.
DOWNLOAD_TIMEOUT = (10, 60)

# Create the target directory; a no-op when it already exists.
os.makedirs(MODEL_DIR, exist_ok=True)

# Only download when the checkpoint marker file is missing.
if not os.path.exists(CHECKPOINT_FILE_TO_CHECK):
    print(f'Downloading pre-trained model: {MODEL_NAME}...')
    print(f'From URL: {MODEL_URL}')

    try:
        # Use a session for better connection management.
        with requests.Session() as session:
            # stream=True avoids loading the whole file into memory; the
            # `with` block guarantees the connection is released afterwards.
            with session.get(MODEL_URL, stream=True, timeout=DOWNLOAD_TIMEOUT) as response:
                # Raise an exception if the server answered with an error status.
                response.raise_for_status()

                # 'r|*' reads the tar archive as a forward-only stream directly
                # from the raw response body.
                with tarfile.open(fileobj=response.raw, mode='r|*') as tar:
                    print(f'Starting extraction to: {MODEL_DIR}')
                    extract_kwargs = {'path': MODEL_DIR}
                    # Where the interpreter supports extraction filters, use the
                    # 'data' filter so members cannot escape MODEL_DIR or create
                    # special files. Rejections raise tarfile.FilterError, a
                    # TarError subclass handled below.
                    if hasattr(tarfile, 'data_filter'):
                        extract_kwargs['filter'] = 'data'
                    tar.extractall(**extract_kwargs)

        print(f'Successfully downloaded and extracted model to directory: {MODEL_DIR}')
        # Confirm the marker file actually came out of the archive.
        if os.path.exists(CHECKPOINT_FILE_TO_CHECK):
            print(f'Verified checkpoint file exists at: {CHECKPOINT_FILE_TO_CHECK}')
        else:
            print(f'Warning: Expected checkpoint index file not found at {CHECKPOINT_FILE_TO_CHECK}. Please check the archive contents.')

    except requests.exceptions.RequestException as e:
        print(f'Error: Could not download model. An error occurred: {e}')
    except tarfile.TarError as e:
        print(f'Error: Could not extract the tar file. It may be corrupted or in an unexpected format. {e}')
    except Exception as e:
        # Deliberate best-effort: report the problem instead of aborting the notebook.
        print(f'An unexpected error occurred: {e}')

else:
    print(f'Model checkpoint index file already exists at: {CHECKPOINT_FILE_TO_CHECK}')


Model checkpoint index file already exists at: models\download.magenta.tensorflow.org\models\music_vae\checkpoints\mel_2bar_big.ckpt.index


In [14]:

import note_seq

def triplets_to_note_sequence(triplets, qpm=120, velocity=100):
    """
    Converts a list of (midi_note, onset_time, duration_time) triplets
    into a note_seq.NoteSequence object.

    One tempo entry is added with the given qpm, one note is added per
    triplet, and the sequence's total_time is set to the latest note end
    (0.0 when `triplets` is empty).

    Args:
        triplets: An iterable of tuples, where each tuple is
                  (midi_note, onset_time, duration_time).
                  midi_note: MIDI pitch (0-127).
                  onset_time: Start time of the note in seconds.
                  duration_time: Duration of the note in seconds.
        qpm: Quarter notes per minute for the NoteSequence tempo.
        velocity: Velocity assigned to every note. Defaults to 100, the
                  value that was previously hard-coded.
    Returns:
        A note_seq.NoteSequence object.
    """
    note_sequence = note_seq.NoteSequence()
    note_sequence.tempos.add().qpm = qpm
    max_end_time = 0.0
    for midi_note, onset_time, duration_time in triplets:
        end_time = onset_time + duration_time
        note = note_sequence.notes.add()
        note.pitch = midi_note
        note.start_time = onset_time
        note.end_time = end_time
        note.velocity = velocity
        # Track the latest note end so total_time covers the whole sequence.
        max_end_time = max(max_end_time, end_time)
    note_sequence.total_time = max_end_time
    return note_sequence

print("Function `triplets_to_note_sequence` defined.")


Function `triplets_to_note_sequence` defined.


In [15]:
# Example usage: an ascending scale from C4 (60) to C5 (72), one note every
# 0.5 s, each lasting 0.5 s.
sample_triplets = [
    (pitch, step * 0.5, 0.5)
    for step, pitch in enumerate((60, 62, 64, 65, 67, 69, 71, 72))
]

# Convert the triplets and report how many notes / how much time the result holds.
sample_ns = triplets_to_note_sequence(sample_triplets)
print(f"Sample NoteSequence created with {len(sample_ns.notes)} notes and total time {sample_ns.total_time}s.")

Sample NoteSequence created with 8 notes and total time 4.0s.


In [21]:
print('Importing libraries and loading the trained model')
# Standard library
import os
import random

# Third-party
import magenta.music as mm
from magenta.models.music_vae import configs
from magenta.models.music_vae.trained_model import TrainedModel
import numpy as np
import tensorflow.compat.v1 as tf

# Switch TensorFlow to v1 behaviour before the model is constructed.
tf.disable_v2_behavior()

# Seed each random number generator in use for reproducibility.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.set_random_seed(SEED)

# Location of the extracted MusicVAE files, relative to the working directory.
BASE_DIR="models/download.magenta.tensorflow.org/models/music_vae"

# Look up the model config, then build the model and restore its weights
# from the checkpoint path.
mel_2bar_config = configs.CONFIG_MAP['cat-mel_2bar_big']
mel_2bar = TrainedModel(
    mel_2bar_config,
    batch_size=4,
    checkpoint_dir_or_path=f'{BASE_DIR}/checkpoints/mel_2bar_big.ckpt',
)

Importing libraries and loading the trained model
INFO:tensorflow:Building MusicVAE model with BidirectionalLstmEncoder, CategoricalLstmDecoder, and hparams:
{'max_seq_len': 32, 'z_size': 512, 'free_bits': 0, 'max_beta': 0.5, 'beta_rate': 0.99999, 'batch_size': 4, 'grad_clip': 1.0, 'clip_mode': 'global_norm', 'grad_norm_clip_to_zero': 10000, 'learning_rate': 0.001, 'decay_rate': 0.9999, 'min_learning_rate': 1e-05, 'conditional': True, 'dec_rnn_size': [2048, 2048, 2048], 'enc_rnn_size': [2048], 'dropout_keep_prob': 1.0, 'sampling_schedule': 'inverse_sigmoid', 'sampling_rate': 1000, 'use_cudnn': False, 'residual_encoder': False, 'residual_decoder': False, 'control_preprocessing_rnn_size': [256]}
INFO:tensorflow:
Encoder Cells (bidirectional):
  units: [2048]

INFO:tensorflow:
Encoder Cells (bidirectional):
  units: [2048]

INFO:tensorflow:
Decoder Cells:
  units: [2048, 2048, 2048]

INFO:tensorflow:
Decoder Cells:
  units: [2048, 2048, 2048]

INFO:tensorflow:Restoring parameters from mod

In [38]:
# Encode the sample sequence with the loaded model. The call is unpacked into
# three values; judging by the names these are presumably a sampled latent
# vector, the latent mean, and its standard deviation -- TODO confirm against
# TrainedModel.encode.
sampled, embedding, sd_embedding = mel_2bar.encode([sample_ns])
# Print only the second returned value.
print(embedding)

[[ 1.90482009e-02  4.82765539e-03 -1.34119764e-03 -4.43474483e-03
  -4.68080584e-03  1.08628422e-02  8.87716189e-03 -1.56116914e-02
  -1.33931041e-02 -2.66647413e-02 -5.86027792e-03 -1.26132015e-02
   1.22770679e+00  1.88999018e-03 -3.53541598e-02  2.01731268e-02
   6.10288046e-03 -8.64286069e-03  1.07047614e-03 -1.33629395e-02
   7.65448529e-03  6.78264629e-03 -1.79836676e-02 -4.57583368e-03
   1.50993037e+00 -2.10679360e-02 -3.46246548e-03 -8.62141140e-03
  -2.06143595e-02 -1.63760744e-02 -3.46385036e-03  2.07197443e-02
  -3.63099854e-03 -2.10210495e-03 -1.74172223e-03  1.22285504e-02
  -1.20626673e-01 -1.06712142e-02  1.58639960e-02  2.21034940e-02
   3.11021879e-03  5.89989685e-02 -1.18029444e-02 -7.19258934e-03
  -4.03045211e-03  2.40522288e-02 -2.61870865e-03 -5.65937301e-03
   7.41415890e-04 -5.82466833e-03  1.52735226e-02  1.22787952e-02
   1.16828494e-02  3.93337719e-02 -2.58544134e-03 -6.48694113e-04
   1.54965185e-03 -8.69640522e-03 -1.49324769e-02 -1.43191330e-02
  -1.36440

In [18]:
# NOTE(review): here `embedding` is bound to the entire return value of
# encode() (the printed output is a tuple of arrays), whereas the other encode
# cell unpacks three values and uses `embedding` for just the second one.
# This cell's execution count (18) is also lower than that cell's (38), so it
# looks like a leftover from an earlier run -- confirm whether it is still
# needed before relying on `embedding` further down.
embedding = mel_2bar.encode([sample_ns])
print(embedding)

(array([[-0.32859844,  2.5385134 , -0.14845121, -2.1416059 ,  0.2500711 ,
         0.9860438 ,  0.01311164,  1.0558989 , -2.6675174 ,  0.48967314,
        -0.46685758, -0.3424712 ,  1.2697699 , -0.49910986,  0.5691808 ,
        -0.64482987,  1.5787976 , -1.7881442 , -0.94753236,  0.46320072,
         0.7382736 , -1.8475722 , -1.7450935 , -0.34709972,  1.6145573 ,
        -0.3899495 ,  0.26486018,  0.31681734, -0.09062053,  0.8023599 ,
        -1.5752419 ,  0.23768356,  0.17349218, -0.20907716, -1.1703043 ,
        -0.3110988 ,  0.8995861 , -0.4059949 ,  0.31302467, -0.60116255,
         0.9392219 ,  0.29153174, -0.21099712,  0.48803258, -0.9972728 ,
         0.90274644,  0.95580745,  0.98795533, -0.42667934, -0.87792856,
         0.02920133,  1.5711143 , -0.7149434 ,  0.5566256 , -0.0250742 ,
        -0.9941577 ,  1.1463552 , -0.43722332, -0.84510696,  0.37793   ,
        -1.309641  , -0.03184675, -0.72319746,  0.10692144,  0.45270488,
        -0.36779246,  0.5196365 ,  1.1493127 ,  0.