# MusicVAE
paper: https://arxiv.org/pdf/1803.05428.pdf  
reference code: https://github.com/magenta/magenta/tree/master/magenta/models/music_vae   
datasets: https://magenta.tensorflow.org/datasets/groove

Bidirectional Encoder + Hierarchical Decoder(conductor + decoder)

## Set up Environment
Docker is the recommended environment.  
If you are not familiar with Docker, you can run this notebook locally in a conda environment instead.

In [None]:
#### ONLY WHEN RUNNING LOCALLY ####
# If you run this notebook locally (not in a Docker container), using a conda environment is recommended.
# Before running this notebook, please install the required packages and magenta.

# conda create --name magenta python=3.7
# conda activate magenta

# pip install -r ./requirements.txt
# pip install -e .

## Preprocess Datasets
### Prepare Datasets

In [None]:
# Download datasets locally
# ! wget https://storage.googleapis.com/magentadata/datasets/groove/groove-v1.0.0-midionly.zip

In [None]:
# unzip datasets
# ! unzip groove-v1.0.0-midionly.zip

### Explore datasets

In [4]:
import pandas as pd
from pathlib import Path

In [5]:
# Root directory of the unzipped Groove MIDI dataset; info.csv and the MIDI files are read relative to it.
data_dir = Path("./groove/")

In [6]:
# Load the dataset's metadata table and preview its first rows.
info_csv_path = data_dir.joinpath('info.csv')
df = pd.read_csv(info_csv_path)
df.head()

Unnamed: 0,drummer,session,id,style,bpm,beat_type,time_signature,midi_filename,audio_filename,duration,split
0,drummer1,drummer1/eval_session,drummer1/eval_session/1,funk/groove1,138,beat,4-4,drummer1/eval_session/1_funk-groove1_138_beat_...,drummer1/eval_session/1_funk-groove1_138_beat_...,27.872308,test
1,drummer1,drummer1/eval_session,drummer1/eval_session/10,soul/groove10,102,beat,4-4,drummer1/eval_session/10_soul-groove10_102_bea...,drummer1/eval_session/10_soul-groove10_102_bea...,37.691158,test
2,drummer1,drummer1/eval_session,drummer1/eval_session/2,funk/groove2,105,beat,4-4,drummer1/eval_session/2_funk-groove2_105_beat_...,drummer1/eval_session/2_funk-groove2_105_beat_...,36.351218,test
3,drummer1,drummer1/eval_session,drummer1/eval_session/3,soul/groove3,86,beat,4-4,drummer1/eval_session/3_soul-groove3_86_beat_4...,drummer1/eval_session/3_soul-groove3_86_beat_4...,44.716543,test
4,drummer1,drummer1/eval_session,drummer1/eval_session/4,soul/groove4,80,beat,4-4,drummer1/eval_session/4_soul-groove4_80_beat_4...,drummer1/eval_session/4_soul-groove4_80_beat_4...,47.9875,test


In [7]:
# Inspect column dtypes and non-null counts of the metadata table.
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1150 entries, 0 to 1149
Data columns (total 11 columns):
 #   Column          Non-Null Count  Dtype  
---  ------          --------------  -----  
 0   drummer         1150 non-null   object 
 1   session         1150 non-null   object 
 2   id              1150 non-null   object 
 3   style           1150 non-null   object 
 4   bpm             1150 non-null   int64  
 5   beat_type       1150 non-null   object 
 6   time_signature  1150 non-null   object 
 7   midi_filename   1150 non-null   object 
 8   audio_filename  1090 non-null   object 
 9   duration        1150 non-null   float64
 10  split           1150 non-null   object 
dtypes: float64(1), int64(1), object(9)
memory usage: 99.0+ KB


### Preprocess datasets with magenta
Magenta provides a script that converts MIDI data to note sequences.  
We use this script to convert the Groove MIDI dataset to note sequences.  
script: magenta/magenta/scripts/convert_dir_to_note_sequences.py

- note sequences look like
```
ticks_per_quarter: 480
time_signatures {
  numerator: 4
  denominator: 4
}
key_signatures {
}
tempos {
  qpm: 95.00014250021376
}
notes {
  pitch: 44
  velocity: 30
  start_time: 1.9473654999999999
  end_time: 2.0486811374999996
  is_drum: true
}
notes {
  pitch: 38
  velocity: 27
  start_time: 2.4434173875
  end_time: 2.4934173124999996
  is_drum: true
}
```

In [None]:
import shutil

In [None]:
# Before running the conversion script, the files must be separated by split,
# so each split gets its own directory under the dataset root.
train_dir = data_dir / "train"
val_dir = data_dir / "validation"
test_dir = data_dir / "test"

# exist_ok=True makes the cell safe to re-run without a separate exists() check.
for split_dir in (train_dir, val_dir, test_dir):
    split_dir.mkdir(exist_ok=True)

In [None]:
# Collect the relative MIDI paths that info.csv assigns to each split.
midi_paths_by_split = {
    split_name: df.loc[df["split"] == split_name, "midi_filename"].tolist()
    for split_name in ("train", "validation", "test")
}
train_data_path_list = midi_paths_by_split["train"]
val_data_path_list = midi_paths_by_split["validation"]
test_data_path_list = midi_paths_by_split["test"]

# Number of files per split according to the metadata.
print(len(train_data_path_list), len(val_data_path_list), len(test_data_path_list))

In [None]:
def copy_split_files(midi_paths, dest_dir, src_root):
    """Copy the MIDI files of one split into a single flat directory.

    Files in different drummer/session folders can share a basename, so every
    component of the relative path is folded into the destination file name
    (e.g. ``a/b/c.mid`` becomes ``a_b_c.mid``). Because the relative paths are
    distinct, two source files cannot silently overwrite each other.

    Args:
        midi_paths: Iterable of MIDI paths relative to ``src_root``.
        dest_dir: Existing directory the files are copied into.
        src_root: Directory the relative paths are resolved against.
    """
    for rel_path in map(Path, midi_paths):
        flat_name = "_".join(rel_path.parts)
        shutil.copy(src_root / rel_path, dest_dir / flat_name)


# Same procedure for every split; only the path list and target directory differ.
for split_paths, split_target in (
    (train_data_path_list, train_dir),
    (val_data_path_list, val_dir),
    (test_data_path_list, test_dir),
):
    copy_split_files(split_paths, split_target, data_dir)

In [None]:
# Sanity check: count the .mid files that ended up in each split directory.
copied_counts = [
    len(list(split_folder.glob("*.mid")))
    for split_folder in (train_dir, val_dir, test_dir)
]
print(*copied_counts)

In [None]:
# Convert each split directory into a TFRecord file of NoteSequences.
# `mkdir -p` keeps this cell re-runnable when the output directory already exists.
! mkdir -p preprocessed_datasets
! python magenta/scripts/convert_dir_to_note_sequences.py --input_dir=./groove/train --output_file=./preprocessed_datasets/groove_train.tfrecord --recursive=True --log=INFO
# The validation files live in ./groove/validation (the directory created for val_dir), not ./groove/val.
! python magenta/scripts/convert_dir_to_note_sequences.py --input_dir=./groove/validation --output_file=./preprocessed_datasets/groove_validation.tfrecord --recursive=True --log=INFO
! python magenta/scripts/convert_dir_to_note_sequences.py --input_dir=./groove/test --output_file=./preprocessed_datasets/groove_test.tfrecord --recursive=True --log=INFO

## Build MusicVAE model

### Make MusicVAE config

In [None]:
import magenta

import collections

from magenta.models.music_vae.configs import Config, HParams, CONFIG_MAP

from magenta.common import merge_hparams
from magenta.contrib import training as contrib_training
from magenta.models.music_vae import data
from magenta.models.music_vae import data_hierarchical
from magenta.models.music_vae import lstm_models
from magenta.models.music_vae.base_model import MusicVAE

import note_seq

In [None]:
# parameters in paper
# NOTE(review): of the values below, only ENCODER_HIDDEN_SIZE, LATENT_DIM,
# DECODER_HIDDEN_SIZE and NUM_SEGMENTS (via MAX_SEQ_LENGTH) are referenced by
# the config cell in this notebook; the rest are currently unused here.

NUM_FEATURES = 2 ** 9
NUM_ENCODER_LAYERS = 2  # not referenced by the config cell (enc_rnn_size has one entry) -- TODO confirm intent
ENCODER_HIDDEN_SIZE = 2048
LATENT_DIM = 512  # passed as z_size
NUM_SEGMENTS = 4
CONDUCTOR_HIDDEN_SIZE = 1024
CONDUCTOR_OUTPUT_SIZE = 512
NUM_DECODER_LAYERS = 2  # not referenced directly; dec_rnn_size is written out as a two-entry list
DECODER_HIDDEN_SIZE = 1024
DECODER_OUTPUT_SIZE = NUM_FEATURES


# hyperparameters

NUM_BARS = 4
# Evaluates to 16 and is passed as max_seq_len.
# NOTE(review): the data converter is configured with slice_bars=4 and
# steps_per_quarter=4; if max_seq_len is counted in steps, 4 bars of 4/4 would
# be 64 steps rather than 16 -- confirm the intended unit.
MAX_SEQ_LENGTH = NUM_BARS * NUM_SEGMENTS
BATCH_SIZE = 512
MAX_BETA = 0.2 # lower beta means better reconstruction
FREE_BITS = 48


In [None]:
# Register a custom groove configuration under its own key in magenta's CONFIG_MAP.
CONFIG_MAP['musicvae_groove_4bar'] = Config(
    # Bidirectional LSTM encoder + hierarchical decoder wrapping a categorical LSTM decoder.
    model=MusicVAE(
        lstm_models.BidirectionalLstmEncoder(),
        lstm_models.HierarchicalLstmDecoder(
            lstm_models.CategoricalLstmDecoder(),
            # Literal values; their product (16) equals MAX_SEQ_LENGTH.
            # NOTE(review): presumably this product has to match max_seq_len -- confirm before changing either.
            level_lengths=[4, 4])
        ),
    # Library default hparams, overridden with the notebook-level constants.
    hparams=merge_hparams(
        lstm_models.get_default_hparams(),
        HParams(
            batch_size=BATCH_SIZE,
            max_seq_len=MAX_SEQ_LENGTH,  
            z_size=LATENT_DIM,
            # One entry only; NUM_ENCODER_LAYERS (= 2) is not used here.
            # NOTE(review): confirm whether a 2-layer encoder was intended; changing this
            # would presumably invalidate existing checkpoints.
            enc_rnn_size=[ENCODER_HIDDEN_SIZE],
            dec_rnn_size=[DECODER_HIDDEN_SIZE, DECODER_HIDDEN_SIZE],
            max_beta=MAX_BETA,
            free_bits=FREE_BITS,
            # NOTE(review): if this is a keep probability, 0.3 drops 70% of units -- verify it is intentional.
            dropout_keep_prob=0.3,
        )),
    note_sequence_augmenter=None,
    data_converter=data.DrumsConverter(
        max_bars=100,  # Truncate long drum sequences before slicing.
        slice_bars=4,
        steps_per_quarter=4,
        roll_input=True),
    train_examples_path="./preprocessed_datasets/groove_train.tfrecord",  # the train tfrecord file path
    eval_examples_path="./preprocessed_datasets/groove_validation.tfrecord" # the eval tfrecord file path
)

## Train MusicVAE

We use the script(https://github.com/magenta/magenta/blob/main/magenta/models/music_vae/music_vae_train.py) to train the model

In [None]:
from magenta.models.music_vae.music_vae_train import main, flags, FLAGS, train, evaluate
from magenta.models.music_vae import configs
import tensorflow.compat.v1 as tf
import os

### Set arguments for training


In [None]:
# Training/evaluation arguments. They stand in for command-line flags and are
# consumed by `run` below as a plain dict.
args = {
    "master": "",
    # Path to a TFRecord file of NoteSequence examples. Overrides the config.
    # Left empty because the examples paths are already set in the config.
    "examples_path": "",
    # TensorFlow Datasets dataset name to use. Overrides the config.
    "tfds_name": "",
    # Path where checkpoints and summary events will be located during training
    # and evaluation. Separate subdirectories 'train' and 'eval' will be created
    # within this directory.
    "run_dir": "./outputs/",
    # Number of training steps or `None` for infinite.
    "num_steps": 500000,
    # Number of batches to use during evaluation or `None` for all batches in
    # the data source.
    "eval_num_batches": 2000,
    # Maximum number of checkpoints to keep in 'train' mode or 0 for infinite.
    "checkpoints_to_keep": 25,
    "keep_checkpoint_every_n_hours": 1,
    # train or eval
    "mode": "train",
    # DEBUG, INFO, WARN, ERROR, or FATAL.
    "log": "INFO",
    "hparams": "",
    "cache_dataset": True,
    "task": 0,
    "num_ps_tasks": 0,
    "num_sync_workers": 0,
    "eval_dir_suffix": "",
}


In [None]:
# `tf` is tensorflow.compat.v1 here; switch off TF2 behavior before running the training code.
tf.disable_v2_behavior()
# Log verbosity is taken from the training arguments.
tf.logging.set_verbosity(args["log"])

def run(config_map, args, tf_file_reader=tf.data.TFRecordDataset, file_reader=tf.python_io.tf_record_iterator):
    """Train or evaluate a MusicVAE model as described by ``args``.

    Args:
        config_map: The Config object to run. Despite its name it is used
            directly as the configuration, not as a name-to-config mapping.
        args: Dict of run options. Keys read here: "run_dir", "mode",
            "hparams", "examples_path", "tfds_name", "cache_dataset", plus
            "checkpoints_to_keep", "keep_checkpoint_every_n_hours",
            "num_steps", "num_sync_workers", "num_ps_tasks", "task" in train
            mode, "eval_num_batches", "eval_dir_suffix" in eval mode, and
            "master" in both.
        tf_file_reader: Forwarded to ``data.get_dataset``.
        file_reader: Forwarded to ``data.count_examples``; only used in eval
            mode when ``args["eval_num_batches"]`` is falsy.

    Raises:
        ValueError: If the run directory is empty, the mode is neither
            'train' nor 'eval', or both "examples_path" and "tfds_name" are
            set.
    """
    raw_run_dir = args["run_dir"]
    if not raw_run_dir:
        raise ValueError('Invalid run directory: %s' % raw_run_dir)
    run_dir = os.path.expanduser(raw_run_dir)
    train_dir = os.path.join(run_dir, 'train')

    mode = args["mode"]
    if mode not in ('train', 'eval'):
        raise ValueError('Invalid mode: %s' % mode)
    # The mode has been validated, so a single boolean captures it.
    is_training = mode == 'train'

    config = config_map
    hparams_override = args["hparams"]
    if hparams_override:
        # Parsed into the hparams object of the config that was passed in.
        config.hparams.parse(hparams_override)

    # Collect path/dataset overrides and apply them to the config in one step.
    overrides = {}
    examples_path = args["examples_path"]
    if examples_path:
        overrides['%s_examples_path' % mode] = os.path.expanduser(examples_path)
    tfds_name = args["tfds_name"]
    if tfds_name:
        if examples_path:
            raise ValueError(
                'At most one of --examples_path and --tfds_name can be set.')
        overrides['tfds_name'] = tfds_name
        overrides['eval_examples_path'] = None
        overrides['train_examples_path'] = None
    config = configs.update_config(config, overrides)

    def dataset_fn():
        """Build the dataset for the selected mode from the final config."""
        return data.get_dataset(
            config,
            tf_file_reader=tf_file_reader,
            is_training=is_training,
            cache_dataset=args["cache_dataset"])

    if is_training:
        train(
            train_dir,
            config=config,
            dataset_fn=dataset_fn,
            checkpoints_to_keep=args["checkpoints_to_keep"],
            keep_checkpoint_every_n_hours=args["keep_checkpoint_every_n_hours"],
            num_steps=args["num_steps"],
            master=args["master"],
            num_sync_workers=args["num_sync_workers"],
            num_ps_tasks=args["num_ps_tasks"],
            task=args["task"])
        return

    # Eval mode: fall back to counting the examples when no batch count is given.
    num_batches = args["eval_num_batches"] or data.count_examples(
        config.eval_examples_path,
        config.tfds_name,
        config.data_converter,
        file_reader) // config.hparams.batch_size
    eval_dir = os.path.join(run_dir, 'eval' + args["eval_dir_suffix"])
    evaluate(
        train_dir,
        eval_dir,
        config=config,
        dataset_fn=dataset_fn,
        num_batches=num_batches,
        master=args["master"])

In [None]:
# Run with the groove config registered in CONFIG_MAP; args["mode"] selects training or evaluation.
run(CONFIG_MAP['musicvae_groove_4bar'], args)

## Generate Samples
Magenta also provides scripts to generate samples from trained models  
scripts: magenta/magenta/models/music_vae/music_vae_generate.py

pretrained model link: https://drive.google.com/file/d/1ALoQpdyUI5oHJCqe92Cb_oimv1CXP3cN/view?usp=sharing

In [None]:
from magenta.models.music_vae.trained_model import TrainedModel
import sys
import time
import numpy as np

In [None]:
# Set your checkpoint path. It should be a .tar file that bundles the checkpoint's .index and .data files.
checkpoint_path = Path('./outputs/model.ckpt-50000.tar')

In [None]:
# Options consumed by generation_run; they stand in for command-line flags.
generation_args = {
    # If set, the checkpoint is loaded from "<run_dir>/train"; otherwise checkpoint_file is used.
    "run_dir": None,
    "checkpoint_file": checkpoint_path,
    # Created if missing; generated MIDI files are written here.
    "output_dir": "./outputs/generation_samples/",
    # 'sample' or 'interpolate'.
    "mode": "sample",
    # Both input MIDI paths are required only in 'interpolate' mode.
    "input_midi_1": None,
    "input_midi_2": None,
    # Number of sequences to sample, or number of interpolation points.
    "num_outputs": 5,
    # The model batch size is min(max_batch_size, num_outputs).
    "max_batch_size": 8,
    # Passed to the model when sampling/decoding.
    "temperature": 0.5,
    # Verbosity passed to tf.logging.
    "log": "INFO"
}

In [None]:
# `tf` is tensorflow.compat.v1; switch off TF2 behavior before building the model.
tf.disable_v2_behavior()
# Alias used by generation_run for its log messages (this is tf.logging, not the stdlib logging module).
logging = tf.logging
logging.set_verbosity(generation_args["log"])

In [None]:
def _slerp(p0, p1, t):
  """Spherical linear interpolation between two vectors.

  Args:
    p0: Start vector (returned for t=0).
    p1: End vector (returned for t=1).
    t: Interpolation fraction, typically in [0, 1].
  Returns:
    The interpolated vector as a NumPy array.
  """
  p0 = np.asarray(p0)
  p1 = np.asarray(p1)
  cos_omega = np.dot(np.squeeze(p0 / np.linalg.norm(p0)),
                     np.squeeze(p1 / np.linalg.norm(p1)))
  # Round-off can push the cosine slightly outside arccos' domain.
  omega = np.arccos(np.clip(cos_omega, -1.0, 1.0))
  so = np.sin(omega)
  if np.isclose(so, 0.0):
    # (Anti-)parallel vectors: the spherical formula would divide by zero,
    # so fall back to plain linear interpolation.
    return (1.0 - t) * p0 + t * p1
  return np.sin((1.0 - t) * omega) / so * p0 + np.sin(t * omega) / so * p1


def generation_run(config_map, generation_args):
  """Load a trained model and write sampled or interpolated MIDI files.

  Args:
    config_map: The Config object of the model to load. Despite its name it is
      used directly as the configuration, not as a name-to-config mapping.
      Its data converter's `max_tensors_per_item` is set to None.
    generation_args: Dict with the keys "run_dir", "checkpoint_file",
      "output_dir", "mode", "input_midi_1", "input_midi_2", "num_outputs",
      "max_batch_size" and "temperature".
  Raises:
    ValueError: if required options are missing or invalid.
  """
  date_and_time = time.strftime('%Y-%m-%d_%H%M%S')

  # Parenthesized on purpose: without the parentheses Python chains the
  # comparisons (`a is None and None == b and b is None`), which lets the
  # "both specified" case through.
  if (generation_args["run_dir"] is None) == (generation_args["checkpoint_file"] is None):
    raise ValueError(
        'Exactly one of `--run_dir` or `--checkpoint_file` must be specified.')
  if generation_args["output_dir"] is None:
    raise ValueError('`--output_dir` is required.')
  tf.gfile.MakeDirs(generation_args["output_dir"])
  if generation_args["mode"] != 'sample' and generation_args["mode"] != 'interpolate':
    raise ValueError('Invalid value for `--mode`: %s' % generation_args["mode"])

  config = config_map
  config.data_converter.max_tensors_per_item = None

  if generation_args["mode"] == 'interpolate':
    if generation_args["input_midi_1"] is None or generation_args["input_midi_2"] is None:
      raise ValueError(
          '`--input_midi_1` and `--input_midi_2` must be specified in '
          '`interpolate` mode.')
    input_midi_1 = os.path.expanduser(generation_args["input_midi_1"])
    input_midi_2 = os.path.expanduser(generation_args["input_midi_2"])
    if not os.path.exists(input_midi_1):
      raise ValueError('Input MIDI 1 not found: %s' % generation_args["input_midi_1"])
    if not os.path.exists(input_midi_2):
      raise ValueError('Input MIDI 2 not found: %s' % generation_args["input_midi_2"])
    input_1 = note_seq.midi_file_to_note_sequence(input_midi_1)
    input_2 = note_seq.midi_file_to_note_sequence(input_midi_2)

    def _check_extract_examples(input_ns, path, input_number):
      """Make sure each input returns exactly one example from the converter."""
      tensors = config.data_converter.to_tensors(input_ns).outputs
      if not tensors:
        print(
            'MusicVAE configs have very specific input requirements. Could not '
            'extract any valid inputs from `%s`. Try another MIDI file.' % path)
        sys.exit()
      elif len(tensors) > 1:
        # Write every extracted candidate so the caller can pick one as input.
        basename = os.path.join(
            generation_args["output_dir"],
            '%s_input%d-extractions_%s-*-of-%03d.mid' %
            ("musicvae_groove_4bar", input_number, date_and_time, len(tensors)))
        for i, ns in enumerate(config.data_converter.from_tensors(tensors)):
          note_seq.sequence_proto_to_midi_file(
              ns, basename.replace('*', '%03d' % i))
        print(
            '%d valid inputs extracted from `%s`. Outputting these potential '
            'inputs as `%s`. Call script again with one of these instead.' %
            (len(tensors), path, basename))
        sys.exit()
    logging.info(
        'Attempting to extract examples from input MIDIs using config `%s`...',
        "musicvae_groove_4bar")
    _check_extract_examples(input_1, generation_args["input_midi_1"], 1)
    _check_extract_examples(input_2, generation_args["input_midi_2"], 2)

  logging.info('Loading model...')
  if generation_args["run_dir"]:
    checkpoint_dir_or_path = os.path.expanduser(
        os.path.join(generation_args["run_dir"], 'train'))
  else:
    checkpoint_dir_or_path = os.path.expanduser(generation_args["checkpoint_file"])
    logging.info('Using checkpoint file `%s`.', checkpoint_dir_or_path)
  model = TrainedModel(
      config, batch_size=min(generation_args["max_batch_size"], generation_args["num_outputs"]),
      checkpoint_dir_or_path=checkpoint_dir_or_path)

  if generation_args["mode"] == 'interpolate':
    logging.info('Interpolating...')
    _, mu, _ = model.encode([input_1, input_2])
    # Interpolate between the two encodings at num_outputs evenly spaced points.
    z = np.array([
        _slerp(mu[0], mu[1], t) for t in np.linspace(0, 1, generation_args["num_outputs"])])
    results = model.decode(
        length=config.hparams.max_seq_len,
        z=z,
        temperature=generation_args["temperature"])
  elif generation_args["mode"] == 'sample':
    logging.info('Sampling...')
    results = model.sample(
        n=generation_args["num_outputs"],
        length=config.hparams.max_seq_len,
        temperature=generation_args["temperature"])

  # '*' in the basename is replaced by the zero-padded index of each output.
  basename = os.path.join(
      generation_args["output_dir"],
      '%s_%s_%s-*-of-%03d.mid' %
      ("musicvae_groove_4bar", generation_args["mode"], date_and_time, generation_args["num_outputs"]))
  logging.info('Outputting %d files as `%s`...', generation_args["num_outputs"], basename)
  for i, ns in enumerate(results):
    note_seq.sequence_proto_to_midi_file(ns, basename.replace('*', '%03d' % i))

  logging.info('Done.')


In [None]:
# Generate MIDI files with the groove config and the options in generation_args.
generation_run(CONFIG_MAP["musicvae_groove_4bar"], generation_args)