In [7]:
# bazel build //magenta/models/melody_rnn:melody_rnn_generate
# bazel run //magenta/models/melody_rnn:melody_rnn_generate -- --config=...


# Arrangement RNN Model

In [8]:
import collections
import functools


# internal imports

import tensorflow as tf

from magenta.models.shared import events_rnn_model
from magenta.music.performance_lib import PerformanceEvent

In [37]:
class ArrangementRnnModel(events_rnn_model.EventSequenceRnnModel):
    def __init__():
        pass
    
class ArrangementRnnConfig(events_rnn_model.EventSequenceRnnConfig):
    """Stores a configuration for a Arrangement RNN.

    Attributes:
        num_velocity_bins: Number of velocity bins to use. If 0, don't use velocity
        at all.
        control_signals: List of PerformanceControlSignal objects to use for
        conditioning, or None if not conditioning on anything.
        optional_conditioning: If True, conditioning can be disabled by setting a
        flag as part of the conditioning input.
    """
    def __init__(self, details, encoder_decoder, hparams, num_velocity_bins=0, 
                 control_signals=None, optional_conditioning=False):
        
        if control_signals is not None:
            control_encoder = magenta.music.MultipleEventSequenceEncoder(
                [control.encoder for control in control_signals])
            if optional_conditioning:
                control_encoder = magenta.music.OptionalEventSequenceEncoder(
                    control_encoder)
                encoder_decoder = magenta.music.ConditionalEventSequenceEncoderDecoder(
                    control_encoder, encoder_decoder)
        
        super(ArrangementRnnConfig, self).__init__(
            details, encoder_decoder, hparams)
        self.num_velocity_bins = num_velocity_bins
        self.control_signals = control_signals
        self.optional_conditioning = optional_conditioning

In [35]:
default_configs = {
    'baseline': ArrangementRnnConfig(
        magenta.protobuf.generator_pb2.GeneratorDetails(
            id='baseline',
            description='The baseline model for Arrangement RNN.'),
        magenta.music.OneHotEventSequenceEncoderDecoder(
            magenta.music.PerformanceOneHotEncoding()),
        tf.contrib.training.HParams(
            batch_size=64,
            rnn_layer_sizes=[512, 512, 512],
            dropout_keep_prob=1.0,
            clip_norm=3,
            learning_rate=0.001)),
}

# Train

In [15]:
import os
import tensorflow as tf
from magenta.models.shared import events_rnn_graph
from magenta.models.shared import events_rnn_train

In [6]:
# Path to the directory where checkpoints and
# summary events will be saved during training and
# evaluation. Separate subdirectories for training
# events and eval events will be created within
# `run_dir`. Multiple runs can be stored within the
# parent directory of `run_dir`. Point TensorBoard
# to the parent directory of `run_dir` to see all
# your runs.
RUN_DIR = './tmp/run_logs/'

CONFIG = 'baseline'

# Path to TFRecord file containing 
# tf.SequenceExample records for training or 
# evaluation.
SEQUENCE_EXAMPLE_FILE = './tmp/sequence_examples/training_performances.tfrecord'

# The the number of global training steps your
# model should take before exiting training.
# Leave as 0 to run until terminated manually.
NUM_TRAINING_STEPS = 0

# The number of evaluation examples your model
# should process for each evaluation step.
# Leave as 0 to use the entire evaluation set.
NUM_EVAL_EXAMPLES = 0

# A summary statement will be logged every
# `summary_frequency` steps during training or
# every `summary_frequency` seconds during
# evaluation.
SUMMARY_FREQUENCY = 10

# The number of most recent checkpoints to keep in
# the training directory. Keeps all if 0.
NUM_CHECKPOINTS = 10

# If True, this process only evaluates the model
# and does not update weights.
EVAL = False

# The threshold for what messages will be logged
# DEBUG, INFO, WARN, ERROR, or FATAL.
LOG = 'INFO'

# Comma-separated list of `name=value` pairs. For
# each pair, the value of the hyperparameter named
# `name` is set to `value`. This mapping is merged
# with the default hyperparameters.
HPARAMS = ''

In [49]:
def train():
    tf.logging.set_verbosity(LOG)

    if not RUN_DIR:
        tf.logging.fatal('RUN_DIR required')
        return
    if not SEQUENCE_EXAMPLE_FILE:
        tf.logging.fatal('SEQUENCE_EXAMPLE_FILE required')
        return

    sequence_example_file_paths = tf.gfile.Glob(
          os.path.expanduser(SEQUENCE_EXAMPLE_FILE))

    run_dir = os.path.expanduser(RUN_DIR)

    config = default_configs[CONFIG]
    config.hparams.parse(HPARAMS)

    mode = 'eval' if EVAL else 'train'

    # A lot of things are happening here
    #
    # events_rnn_graph.build_graph_fn returns a function 
    # which builds the TF ops when called. That function is later
    # invoked from inside events_rnn_train
    #
    #
    # events_rnn_graph makes a call to 
    # magenta.common.sequence_example_lib.get_padded_batch
    # which reads batches of SequenceExamples from TFRecords and
    # pads them to the length of the longest sequence.
    # get_padded_batch returns tf.train.batch with
    #     inputs:  A tensor of shape [batch_size, num_steps, input_size] of floats32s.
    #     labels:  A tensor of shape [batch_size, num_steps] of int64s.
    #     lengths: A tensor of shape [batch_size] of int32s. The lengths of each
    #              SequenceExample before padding.
    #
    #     batch_size and input_size are arguments. input_size is the size of 
    #     each input vector.
    #
    build_graph_fn = events_rnn_graph.get_build_graph_fn(
        mode, config, sequence_example_file_paths)

    train_dir = os.path.join(RUN_DIR, 'train')
    tf.gfile.MakeDirs(train_dir)
    tf.logging.info('Train dir: %s', train_dir)
    
    
    if EVAL: # Evaluate only
        eval_dir = os.path.join(run_dir, 'eval')
        tf.gfile.MakeDirs(eval_dir)
        tf.logging.info('Eval dir: %s', eval_dir)
        num_batches = (
            (NUM_EVAL_EXAMPLES if NUM_EVAL_EXAMPLES else
            magenta.common.count_records(sequence_example_file_paths)) //
            config.hparams.batch_size)
        events_rnn_train.run_eval(build_graph_fn, train_dir, eval_dir, num_batches)
    
    else: # Update weights & Evaluate
        events_rnn_train.run_training(build_graph_fn, train_dir,
                                        NUM_TRAINING_STEPS,
                                        SUMMARY_FREQUENCY,
                                        checkpoints_to_keep=NUM_CHECKPOINTS)

In [48]:
train()

INFO:tensorflow:hparams = {'batch_size': 64, 'rnn_layer_sizes': [512, 512, 512], 'dropout_keep_prob': 1.0, 'clip_norm': 3, 'learning_rate': 0.001}
INFO:tensorflow:Train dir: ./tmp/run_logs/train
INFO:tensorflow:Counting records in ./tmp/sequence_examples/training_performances.tfrecord.
INFO:tensorflow:Number of records is at least 100.
INFO:tensorflow:[<tf.Tensor 'random_shuffle_queue_Dequeue:0' shape=(?, 356) dtype=float32>, <tf.Tensor 'random_shuffle_queue_Dequeue:1' shape=(?,) dtype=int64>, <tf.Tensor 'random_shuffle_queue_Dequeue:2' shape=() dtype=int32>]
INFO:tensorflow:Starting training loop...
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from ./tmp/run_logs/train/model.ckpt-5
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 6 into ./tmp/run_logs/train/model.ckpt.
INFO:tensorflow:Global Step = 6, Loss = 4.82739, Perplexity = 124.884, Ac

KeyboardInterrupt: 

In [None]:
sess = tf.InteractiveSession()

In [None]:
from magenta.common import get_padded_batch
import tensorflow as tf
import os

sequence_example_file_paths = tf.gfile.Glob(
      os.path.expanduser(SEQUENCE_EXAMPLE_FILE))
    
    
inputs, labels, lengths = get_padded_batch(sequence_example_file_paths, 5, 10)

In [None]:
inputs