In [None]:

import os
import requests
import tarfile
import io

# --- Step 1: Download the pre-trained MusicVAE checkpoint ---
# The archive is streamed from the HTTP response straight into tarfile, so the
# .tar itself is never held in memory or written to disk.

# Name of the checkpoint bundle on the Magenta download server. Note that the
# checkpoint file is called 'mel_2bar_big', while the MusicVAE config it is
# loaded with later is 'cat-mel_2bar_big'.
MODEL_NAME = 'mel_2bar_big'
# Derived from MODEL_NAME so the two can never disagree.
MODEL_URL = f'http://download.magenta.tensorflow.org/models/music_vae/checkpoints_bundled/{MODEL_NAME}.ckpt.tar'

MODEL_DIR = 'models'  # Directory the archive is extracted into
# Sub-path (mirroring the server layout) under which the archive is expected to
# place its checkpoint files once extracted into MODEL_DIR.
CHECKPOINT_SUBDIR = os.path.join('download.magenta.tensorflow.org', 'models', 'music_vae', 'checkpoints')
# TensorFlow checkpoint prefix; the actual files are '<prefix>.index', '<prefix>.data-*', ...
CHECKPOINT_BASE_NAME = os.path.join(MODEL_DIR, CHECKPOINT_SUBDIR, MODEL_NAME + '.ckpt')
# The .index file is used as the marker of a successful extraction.
CHECKPOINT_FILE_TO_CHECK = CHECKPOINT_BASE_NAME + '.index'

# Create the target directory; exist_ok avoids the check-then-create race.
os.makedirs(MODEL_DIR, exist_ok=True)

# (connect, read) timeouts in seconds, so a stalled connection cannot hang the cell forever.
DOWNLOAD_TIMEOUT_S = (10, 60)


def _safe_extract_stream(tar, dest):
    """Extract every member of a (possibly streaming) tar archive into ``dest``.

    Uses the stdlib ``'data'`` extraction filter when this Python provides it
    (it rejects absolute paths, ``..`` traversal, links escaping ``dest`` and
    special files). On older Pythons it falls back to checking that each member
    resolves inside ``dest`` and is a regular file or directory.

    Raises:
        tarfile.TarError: If a member would be written outside ``dest`` or is
            not a regular file/directory (filter errors are TarError subclasses).
    """
    if hasattr(tarfile, 'data_filter'):
        tar.extractall(path=dest, filter='data')
        return

    dest_root = os.path.realpath(dest)
    # Iterating and extracting member by member works for forward-only streams.
    for member in tar:
        target = os.path.realpath(os.path.join(dest_root, member.name))
        if os.path.commonpath([dest_root, target]) != dest_root:
            raise tarfile.TarError(f'Refusing to extract {member.name!r}: it resolves outside {dest!r}')
        if not (member.isfile() or member.isdir()):
            raise tarfile.TarError(f'Refusing to extract {member.name!r}: unsupported member type')
        tar.extract(member, path=dest)


# Check if the model is already downloaded
if not os.path.exists(CHECKPOINT_FILE_TO_CHECK):
    print(f'Downloading pre-trained model: {MODEL_NAME}...')
    print(f'From URL: {MODEL_URL}')

    # Each handler below prints a readable message and then re-raises, so
    # Restart & Run All stops here instead of failing later when the model is
    # loaded from a missing checkpoint.
    # NOTE(review): a failure mid-extraction can leave partial files behind; if
    # the .index file was among them, the existence check above will skip the
    # download next time -- delete MODEL_DIR before retrying.
    try:
        # Use a session for connection management; it is closed on exit.
        with requests.Session() as session:
            # stream=True avoids loading the whole archive into memory; using the
            # response as a context manager releases the connection when done.
            with session.get(MODEL_URL, stream=True, timeout=DOWNLOAD_TIMEOUT_S) as response:
                # Raise an exception if the server returned an HTTP error status
                response.raise_for_status()
                # response.raw is the undecoded body; have urllib3 undo any
                # Content-Encoding (e.g. gzip) so tarfile sees the archive bytes.
                response.raw.decode_content = True

                # 'r|*' reads the archive as a forward-only stream (a socket
                # cannot seek) and detects any compression automatically.
                with tarfile.open(fileobj=response.raw, mode='r|*') as tar:
                    print(f'Starting extraction to: {MODEL_DIR}')
                    _safe_extract_stream(tar, MODEL_DIR)

        print(f'Successfully downloaded and extracted model to directory: {MODEL_DIR}')
        # We check for the .ckpt.index file which is part of the actual output.
        if os.path.exists(CHECKPOINT_FILE_TO_CHECK):
            print(f'Verified checkpoint file exists at: {CHECKPOINT_FILE_TO_CHECK}')
        else:
            print(f'Warning: Expected checkpoint index file not found at {CHECKPOINT_FILE_TO_CHECK}. Please check the archive contents.')

    except requests.exceptions.RequestException as e:
        print(f'Error: Could not download model. An error occurred: {e}')
        raise
    except tarfile.TarError as e:
        print(f'Error: Could not extract the tar file. It may be corrupted or in an unexpected format. {e}')
        raise
    except Exception as e:
        print(f'An unexpected error occurred: {e}')
        raise

else:
    print(f'Model checkpoint index file already exists at: {CHECKPOINT_FILE_TO_CHECK}')


In [None]:

import note_seq

def triplets_to_note_sequence(triplets, qpm=120, velocity=100):
    """Convert (midi_note, onset_time, duration_time) triplets into a NoteSequence.

    Args:
        triplets: Iterable of ``(midi_note, onset_time, duration_time)`` tuples.
                  midi_note: MIDI pitch (0-127).
                  onset_time: Start time of the note in seconds (>= 0).
                  duration_time: Duration of the note in seconds (>= 0).
        qpm: Quarter notes per minute, stored as the sequence's single tempo
            event. Must be positive.
        velocity: MIDI velocity (0-127) given to every note. Defaults to 100.

    Returns:
        A ``note_seq.NoteSequence`` with one note per triplet (in input order),
        one tempo event, and ``total_time`` set to the latest note end time
        (0.0 when ``triplets`` is empty).

    Raises:
        ValueError: If ``qpm`` is not positive, ``velocity`` or a pitch is
            outside 0-127, or an onset or duration is negative.
    """
    if qpm <= 0:
        raise ValueError(f'qpm must be positive, got {qpm}')
    if not 0 <= velocity <= 127:
        raise ValueError(f'velocity must be in 0-127, got {velocity}')

    note_sequence = note_seq.NoteSequence()
    note_sequence.tempos.add().qpm = qpm
    max_end_time = 0.0
    for midi_note, onset_time, duration_time in triplets:
        # Validate before adding, so bad input is reported together with the
        # offending triplet rather than surfacing as a malformed note later.
        if not 0 <= midi_note <= 127:
            raise ValueError(f'MIDI pitch must be in 0-127, got {midi_note}')
        if onset_time < 0:
            raise ValueError(f'onset_time must be >= 0, got {onset_time} for pitch {midi_note}')
        if duration_time < 0:
            raise ValueError(f'duration_time must be >= 0, got {duration_time} for pitch {midi_note}')

        note = note_sequence.notes.add()
        note.pitch = midi_note
        note.start_time = onset_time
        note.end_time = onset_time + duration_time
        note.velocity = velocity
        max_end_time = max(max_end_time, note.end_time)
    note_sequence.total_time = max_end_time
    return note_sequence

print("Function `triplets_to_note_sequence` defined.")


In [None]:
# Example usage: one ascending C-major octave from C4 (MIDI 60) to C5 (MIDI 72),
# played as eight back-to-back half-second notes (C4 D4 E4 F4 G4 A4 B4 C5).
sample_triplets = [
    (pitch, step * 0.5, 0.5)
    for step, pitch in enumerate((60, 62, 64, 65, 67, 69, 71, 72))
]
sample_ns = triplets_to_note_sequence(sample_triplets)
print(f"Sample NoteSequence created with {len(sample_ns.notes)} notes and total time {sample_ns.total_time}s.")

In [None]:
print('Importing libraries and loading the trained model')
# NOTE(review): `mm` is not used in the visible cells, and `magenta.music` is
# deprecated in favour of the standalone `note_seq` package -- confirm it still
# imports with the installed Magenta version.
import magenta.music as mm
from magenta.models.music_vae import configs
from magenta.models.music_vae.trained_model import TrainedModel
import numpy as np
import os
import tensorflow.compat.v1 as tf
import random

# TrainedModel is used through the TF1 compat API, so V2 behaviour is switched
# off before the model is constructed.
tf.disable_v2_behavior()

# Set random seeds for reproducibility
SEED = 42
np.random.seed(SEED)
# NOTE(review): this sets the seed on the *default* graph; if TrainedModel builds
# its own tf.Graph, the sampled latent z will not be affected -- verify.
tf.set_random_seed(SEED)
random.seed(SEED)

# The 'mel_2bar_big' checkpoint is loaded with the 'cat-mel_2bar_big' config.
mel_2bar_config = configs.CONFIG_MAP['cat-mel_2bar_big']

BASE_DIR="models/download.magenta.tensorflow.org/models/music_vae"
# TensorFlow checkpoint *prefix*; the files on disk are '<prefix>.index', '<prefix>.data-*'.
MEL_2BAR_CHECKPOINT = os.path.join(BASE_DIR, 'checkpoints', 'mel_2bar_big.ckpt')
BATCH_SIZE = 4

# Fail early with an actionable message instead of an opaque TF restore error
# when the download cell has not produced the checkpoint.
if not os.path.exists(MEL_2BAR_CHECKPOINT + '.index'):
    raise FileNotFoundError(
        f'Checkpoint not found at {MEL_2BAR_CHECKPOINT}.index -- run the download cell first.')

mel_2bar = TrainedModel(mel_2bar_config, batch_size=BATCH_SIZE, checkpoint_dir_or_path=MEL_2BAR_CHECKPOINT)

In [None]:
# Encode the sample melody. Magenta's `TrainedModel.encode` returns a
# (z, mu, sigma) tuple: the sampled latent, the posterior mean and the posterior
# standard deviation. `embedding` holds the mean vector.
(sampled,
 embedding,
 sd_embedding) = mel_2bar.encode([sample_ns])
print(embedding)

In [None]:
# `TrainedModel.encode` returns a (z, mu, sigma) tuple. Keep the whole tuple
# under its own name rather than rebinding `embedding`, which the previous cell
# set to the mean vector `mu`. Since z is (per Magenta) sampled, it may differ
# from the previous call.
encoder_outputs = mel_2bar.encode([sample_ns])
print(encoder_outputs)