In [2]:
import os
import sys
import boto3
import json
import numpy as np
import scipy.io.wavfile as sciwav
import IPython.display
import shutil

from magenta.models.nsynth import utils
from magenta.models.nsynth.wavenet.fastgen import encode
import numpy as np
import tensorflow.compat.v1 as tf

Instructions for updating:
Use tf.initializers.variance_scaling instead with distribution=uniform to get equivalent behavior.


In [3]:
# Run configuration shared by the cells below.
bucket = 'music-ml-gigioli'                         # S3 bucket holding the NSynth data
checkpoint_path = 'wavenet-ckpt/model.ckpt-200000'  # WaveNet checkpoint passed to `encode`
sample_length = 64_000                              # samples loaded per clip (presumably 4 s @ 16 kHz — confirm)
batch_size = 16                                     # clips per encoder call

In [4]:
# Fetch the NSynth training-set index from S3. Its top-level keys are used
# below as the stems of the per-example WAV files.
s3 = boto3.client('s3')
obj = s3.get_object(
    Bucket=bucket,
    Key='data/nsynth/nsynth-train/examples.json',
)
train_examples = json.loads(
    obj['Body'].read().decode('utf-8')
)

In [None]:
# Encode the NSynth training set batch by batch:
#   download each batch's WAVs from S3 into a scratch directory,
#   load/trim them, and run the WaveNet encoder.
# Relies on `s3`, `bucket`, `train_examples`, `batch_size`, `sample_length` and
# `checkpoint_path` from the cells above.
WAV_DIR = '/tmp/nsynth_wav_files'
# Exploration guard: stop after this many batches (the full training set is
# large). Set to None to process every batch.
MAX_BATCHES = 1

wavfiles = list(train_examples.keys())

# Start from an empty scratch directory.
if os.path.exists(WAV_DIR):
    shutil.rmtree(WAV_DIR)
os.mkdir(WAV_DIR)

for start_file in range(0, len(wavfiles), batch_size):
    # Integer division keeps the "%d" batch counter an exact int.
    batch_number = start_file // batch_size + 1
    tf.logging.info("On file number %s (batch %d).", start_file, batch_number)
    end_file = start_file + batch_size
    batch_keys = ['data/nsynth/nsynth-train/audio/{}.wav'.format(x)
                  for x in wavfiles[start_file:end_file]]

    # Download this batch's wav files locally, remembering each local path.
    wavefiles_batch = []
    for key in batch_keys:
        local_path = os.path.join(WAV_DIR, key.split('/')[-1])
        s3.download_file(bucket, key, local_path)
        wavefiles_batch.append(local_path)

    # The encoder is always fed batch_size clips: pad a short final batch by
    # repeating its last file.
    batch_filler = batch_size - len(wavefiles_batch)
    wavefiles_batch.extend(batch_filler * [wavefiles_batch[-1]])
    wav_data = [utils.load_audio(f, sample_length) for f in wavefiles_batch]
    # Trim every clip to the shortest so they stack into a single array.
    min_len = min(x.shape[0] for x in wav_data)
    wav_data = np.array([x[:min_len] for x in wav_data])

    try:
        tf.reset_default_graph()

        # Load up the model for encoding and find the encoding.
        encoding = encode(wav_data, checkpoint_path, sample_length=sample_length)
        if encoding.ndim == 2:
            encoding = np.expand_dims(encoding, 0)

        tf.logging.info("Encoding:")
        tf.logging.info("%s", encoding.shape)
        tf.logging.info("Sample length: %d", sample_length)

        # TODO: persist one embedding per real (non-filler) file, as `main`
        # does, once a save location for this notebook is chosen.
    except Exception as e:
        tf.logging.info("Unexpected error happened: %s.", e)
        raise

    if MAX_BATCHES is not None and batch_number >= MAX_BATCHES:
        break

# Downloaded WAVs are intentionally left in WAV_DIR; a later cell lists them.

In [5]:
# Placeholder 16 x 512 array of uniform [0, 1) values used to exercise the save
# step below. Seeded so that re-running the notebook writes an identical file.
wav_data = np.random.RandomState(0).rand(16, 512)

In [8]:
# Recreate an empty embeddings directory and write `wav_data` into it as test.npy.
EMBEDDINGS_DIR = '/tmp/nsynth_wav_embeddings'

if os.path.exists(EMBEDDINGS_DIR):
    shutil.rmtree(EMBEDDINGS_DIR)
os.mkdir(EMBEDDINGS_DIR)

test_npy_path = os.path.join(EMBEDDINGS_DIR, 'test.npy')
np.save(test_npy_path, wav_data)

In [None]:
# Local paths of the last batch processed by the encoding loop (padded to
# batch_size). The loop names this `wavefiles_batch`; `wavfiles_batch` is undefined.
wavefiles_batch

In [12]:
# List the WAV files downloaded into the scratch directory by the encoding loop.
!ls /tmp/nsynth_wav_files

bass_synthetic_016-080-127.wav	     keyboard_electronic_089-044-100.wav
bass_synthetic_120-108-050.wav	     keyboard_electronic_100-040-025.wav
guitar_acoustic_001-082-050.wav      organ_electronic_002-068-100.wav
guitar_electronic_021-026-025.wav    organ_electronic_011-079-075.wav
guitar_electronic_035-062-127.wav    organ_electronic_111-065-100.wav
keyboard_acoustic_010-084-127.wav    organ_electronic_120-050-127.wav
keyboard_acoustic_011-053-127.wav    vocal_synthetic_007-064-025.wav
keyboard_electronic_065-069-025.wav  vocal_synthetic_012-086-100.wav


In [10]:
# Confirm the array saved by the previous cell landed in the embeddings directory.
!ls /tmp/nsynth_wav_embeddings

test.npy


In [None]:


# Command-line flags read by `main`.
# NOTE(review): inside a notebook these flags are never parsed from argv;
# absl raises on access to an unparsed flag, so `main()` likely needs
# `FLAGS(['prog', '--source_path=...', ...])` (or `tf.app.run`) first — confirm.
FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string("source_path", "",
                           "The directory of WAVs to yield embeddings from.")
tf.app.flags.DEFINE_string("save_path", "", "The directory to save "
                           "the embeddings.")
tf.app.flags.DEFINE_string("checkpoint_path", "",
                           "A path to the checkpoint. If not given, the latest "
                           "checkpoint in `expdir` will be used.")
tf.app.flags.DEFINE_string("expdir", "",
                           "The log directory for this experiment. Required if "
                           "`checkpoint_path` is not given.")
tf.app.flags.DEFINE_integer("sample_length", 64000, "Sample length.")
tf.app.flags.DEFINE_integer("batch_size", 16, "Batch size.")
# Adjacent literals are concatenated, so the first one needs a trailing space.
tf.app.flags.DEFINE_string("log", "INFO",
                           "The threshold for what messages will be logged. "
                           "DEBUG, INFO, WARN, ERROR, or FATAL.")


def main(unused_argv=None):
  """Encode every WAV in `--source_path` and save one embedding file per WAV.

  Configuration comes from the module-level FLAGS:
    * Checkpoint: `--checkpoint_path` if given, otherwise the latest
      checkpoint found in `--expdir`.
    * Inputs: all files in `--source_path` whose name ends in `.wav`
      (case-insensitive), processed in sorted order in batches of
      `--batch_size`.
    * Outputs: `<stem>_embeddings.npy` files written to `--save_path`, which
      is created if missing.

  Args:
    unused_argv: Ignored; kept for the conventional `main(argv)` signature.

  Raises:
    SystemExit: (status 1, after a fatal log) if the experiment directory or
      checkpoint cannot be found.
    Exception: any error from loading, encoding or saving is logged and
      re-raised.
  """
  tf.logging.set_verbosity(FLAGS.log)

  if FLAGS.checkpoint_path:
    checkpoint_path = utils.shell_path(FLAGS.checkpoint_path)
  else:
    expdir = utils.shell_path(FLAGS.expdir)
    tf.logging.info("Will load latest checkpoint from %s.", expdir)
    if not tf.gfile.Exists(expdir):
      tf.logging.fatal("\tExperiment save dir '%s' does not exist!", expdir)
      sys.exit(1)

    try:
      checkpoint_path = tf.train.latest_checkpoint(expdir)
    except tf.errors.NotFoundError:
      checkpoint_path = None
    # latest_checkpoint returns None (rather than raising) when no checkpoint
    # is found, so both failure modes end up here.
    if checkpoint_path is None:
      tf.logging.fatal("There was a problem determining the latest checkpoint.")
      sys.exit(1)

  if not tf.train.checkpoint_exists(checkpoint_path):
    tf.logging.fatal("Invalid checkpoint path: %s", checkpoint_path)
    sys.exit(1)

  tf.logging.info("Will restore from checkpoint: %s", checkpoint_path)

  source_path = utils.shell_path(FLAGS.source_path)
  tf.logging.info("Will load Wavs from %s.", source_path)

  save_path = utils.shell_path(FLAGS.save_path)
  tf.logging.info("Will save embeddings to %s.", save_path)
  if not tf.gfile.Exists(save_path):
    tf.logging.info("Creating save directory...")
    tf.gfile.MakeDirs(save_path)

  sample_length = FLAGS.sample_length
  batch_size = FLAGS.batch_size

  wavfiles = sorted(
      os.path.join(source_path, fname)
      for fname in tf.gfile.ListDirectory(source_path)
      if fname.lower().endswith(".wav"))
  if not wavfiles:
    tf.logging.warning("No .wav files found in %s.", source_path)

  for start_file in range(0, len(wavfiles), batch_size):
    # Integer division keeps the "%d" batch counter an exact int.
    batch_number = start_file // batch_size + 1
    tf.logging.info("On file number %s (batch %d).", start_file, batch_number)
    real_files = wavfiles[start_file:start_file + batch_size]

    # The encoder is always fed batch_size clips: pad a short final batch by
    # repeating its last file. Only `real_files` get embeddings saved.
    batch_filler = batch_size - len(real_files)
    wavefiles_batch = real_files + batch_filler * [real_files[-1]]
    wav_data = [utils.load_audio(f, sample_length) for f in wavefiles_batch]
    # Trim every clip to the shortest so they stack into a single array.
    min_len = min(x.shape[0] for x in wav_data)
    wav_data = np.array([x[:min_len] for x in wav_data])
    try:
      tf.reset_default_graph()
      # Load up the model for encoding and find the encoding.
      encoding = encode(wav_data, checkpoint_path, sample_length=sample_length)
      if encoding.ndim == 2:
        encoding = np.expand_dims(encoding, 0)

      tf.logging.info("Encoding:")
      tf.logging.info("%s", encoding.shape)
      tf.logging.info("Sample length: %d", sample_length)

      # zip stops at the real files, so filler duplicates are never saved.
      for wavfile, enc in zip(real_files, encoding):
        # splitext removes exactly the extension; str.strip(".wav") would also
        # eat leading/trailing '.', 'w', 'a', 'v' characters ("vocal_" -> "ocal_").
        stem = os.path.splitext(os.path.basename(wavfile))[0]
        filename = "%s_embeddings.npy" % stem
        # np.save emits binary data, so the file must be opened in binary mode.
        with tf.gfile.Open(os.path.join(save_path, filename), "wb") as f:
          np.save(f, enc)
    except Exception as e:
      tf.logging.info("Unexpected error happened: %s.", e)
      raise