In [None]:
!pip uninstall ddsp

In [None]:
!which python

In [None]:
def print_dict(d):
    """Print each entry of the mapping `d` on its own tab-indented line as `'key': value`."""
    formatted_entries = (f"\t'{name}': {value}" for name, value in d.items())
    for entry_line in formatted_entries:
        print(entry_line)

def print_dict2(d):
    """Pretty-print a sequence of mappings (used here for TFLite tensor details).

    Every item is introduced by a blank line and an `item #<n>:` header. The
    first entry of an item is indented with one tab and all following entries
    with two tabs, so the leading entry stands out as the item's heading.

    Args:
        d: iterable of mappings; each mapping is printed as `'key': value` lines.
    """
    for item_idx, detail in enumerate(d):
        print()
        print(f'item #{item_idx}:')

        # enumerate() supplies the position within the item; the previous manual
        # counter was never advanced, so the two-tab branch could not be reached.
        for field_idx, (k, v) in enumerate(detail.items()):
            indent = "\t" if field_idx == 0 else "\t\t"
            print(f"{indent}'{k}': {v}")

In [None]:
import os
import platform
import sys

# The midi_ddsp checkout sits two directories up on every platform; only the
# relative location of the ddsp checkout depends on the operating system.
midi_ddsp_module_path = os.path.abspath(os.path.join('../../'))
ddsp_module_path = os.path.abspath(os.path.join(
    '../../../ddsp-playground-2/' if platform.system() == 'Windows'
    else '../../../ddsp/ddsp-playground-2/'
))

def apply_module_path(module_path):
    """Append `module_path` to `sys.path` unless it is already listed there.

    Prints the path first, then a line saying whether it was appended or not.
    """
    print(f"module_path={module_path}")

    already_listed = module_path in sys.path
    if already_listed:
        print(f"do not appending {module_path} to sys.path")
        return

    sys.path.append(module_path)
    print(f"appending {module_path} to sys.path")

# Register the local midi_ddsp and ddsp checkouts on sys.path so that the
# `import midi_ddsp` / `from ddsp...` statements in later cells can find them.
apply_module_path(midi_ddsp_module_path)
apply_module_path(ddsp_module_path)

import sys
if platform.system() != 'Windows':
    # NOTE(review): the two paths below are absolute and user-specific; they
    # presumably exist only on the original author's cluster account.
    # sparsenet_module_path_abs is assigned but its apply call is commented out.
    sparsenet_module_path_abs = '/ssd003/home/burakovr/projects/vova/envs/main/lib/python3.8/site-packages/'
    #apply_module_path(sparsenet_module_path_abs)

    # site-packages of the 'ddsp' environment, appended so its packages become importable.
    libs_modules_path_abs = '/ssd003/home/burakovr/projects/vova/envs/ddsp/lib/python3.8/site-packages/'
    apply_module_path(libs_modules_path_abs)

import midi_ddsp

In [None]:
# The imports below are taken from Magenta's original 'train_synthesis_generator.py' (its license header is kept below); this is attributed reuse, not plagiarism.

#  Copyright 2022 The MIDI-DDSP Authors.
#  #
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#  #
#      http://www.apache.org/licenses/LICENSE-2.0
#  #
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

"""Training code for Synthesis Generator."""

import tensorflow as tf
import time
import os
import sys
import logging
import argparse
import IPython

from keras.utils.layer_utils import print_summary
from livelossplot import PlotLosses

from midi_ddsp.data_handling.get_dataset import get_dataset
from midi_ddsp.utils.training_utils import print_hparams, set_seed, \
    save_results, str2bool
from midi_ddsp.utils.summary_utils import write_tensorboard_audio
#                          from midi_ddsp.hparams_synthesis_generator import hparams as hp
from midi_ddsp.hparams_synthesis_generator import hparams_debug as hp
from midi_ddsp.modules.recon_loss import ReconLossHelper
from midi_ddsp.modules.gan_loss import GANLossHelper
from midi_ddsp.modules.get_synthesis_generator import get_synthesis_generator, \
    get_fake_data_synthesis_generator
from midi_ddsp.modules.discriminator import Discriminator
from ddsp.colab.notebook_utils import play, specplot


In [None]:
print(tf.__version__)

In [None]:
plotlosses = PlotLosses()


losses_history = [] # per step
eval_losses_history = [] # per epoch

In [None]:
def load_tflite_model(model_path):
    """Create a TFLite interpreter for `model_path` and print its tensor details.

    Returns:
        A tuple `(interpreter, input_details, output_details)`, where the two
        detail lists come from the interpreter's `get_input_details()` and
        `get_output_details()`.
    """
    interpreter = tf.lite.Interpreter(model_path=model_path)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Dump the inputs first, then the outputs, each with a count in the header.
    print(f"\n\n Inputs (len={len(input_details)}):")
    print_dict2(input_details)

    print(f"\n\n Outputs (len={len(output_details)}):")
    print_dict2(output_details)

    return interpreter, input_details, output_details

In [None]:
def run_tflite_inference(interpreter, inputs):
    """Run one inference on a TFLite interpreter and collect synthesizer parameters.

    Args:
        interpreter: TFLite interpreter, e.g. the one returned by `load_tflite_model`.
        inputs: dict mapping input names to values. A key is matched against each
            input tensor's full name and against the name with its first 16 and
            last 2 characters removed (presumably the 'serving_default_' prefix
            and ':0' suffix -- TODO confirm). Keys without a match are reported
            and skipped.

    Returns:
        dict with the keys 'f0_hz', 'amplitudes', 'harmonic_distribution' and
        'noise_magnitudes', read back from the interpreter after `invoke()`.
    """
    interpreter.allocate_tensors()

    # Take the tensor details from the interpreter that is actually being run,
    # instead of relying on same-named notebook globals left by another cell
    # (which may be missing or may describe a different model).
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    def find_input_detail(key):
        """Return the input detail whose (stripped or full) name equals `key`, else None."""
        for detail in input_details:
            if detail['name'][16:-2] == key or detail['name'] == key:
                return detail
        return None

    def set_inputs_for_tflite(interpreter_input):
        """Copy every matching entry of `interpreter_input` into the interpreter."""
        for key, value in interpreter_input.items():
            print(f"")

            good_detail = find_input_detail(key)
            if good_detail is None:
                print(f"Unable to find {key}. Skipping")
                continue

            print(f"Our raw input value: {tf.shape(value)}")
            print(f"trying to set {good_detail['name']} with tensor of shape {tf.shape(value)}")

            # Cast to the dtype the input tensor declares before handing it over.
            interpreter.set_tensor(good_detail['index'], tf.cast(value, good_detail['dtype']))

    def get_outputs_from_tflite():
        """Read f0 from input #1 and the synth controls from outputs #0, #3 and #2."""
        # NOTE(review): the positional indices below are hard-coded; they are
        # assumed to match the tensor order of the exported model -- confirm
        # against the details printed by load_tflite_model.
        f0_hz  = interpreter.get_tensor(input_details[1]['index'])
        amps   = interpreter.get_tensor(output_details[0]['index'])
        noises = interpreter.get_tensor(output_details[3]['index'])
        hd     = interpreter.get_tensor(output_details[2]['index'])

        return {
            'f0_hz': f0_hz,
            'amplitudes': amps,
            'harmonic_distribution': hd,
            'noise_magnitudes': noises
        }

    set_inputs_for_tflite(inputs)

    interpreter.invoke()

    return get_outputs_from_tflite()

In [None]:
import tensorflow_io as tfio

def resample_tensor(tensor, input_rate, output_rate):
    """Resample `tensor` from `input_rate` to `output_rate` via `tfio.audio.resample`.

    A rank-2 tensor gets a trailing axis of size 1 added before resampling and
    removed again afterwards, so the result has the same rank as the input.
    Tensors of any other rank are passed to the resampler unchanged.

    Args:
        tensor: tensor to resample (e.g. shape [1, 1000]).
        input_rate: rate the values of `tensor` are currently sampled at.
        output_rate: rate to resample to.

    Returns:
        The resampled tensor.
    """
    # Remember the original rank: `tensor` is rebound to the expanded version
    # below, so re-checking its shape afterwards would never see rank 2 and the
    # added axis would never be removed.
    was_rank_2 = len(tensor.shape) == 2

    if was_rank_2:
        tensor = tf.expand_dims(tensor, axis=-1)  # e.g. [1, 1000] -> [1, 1000, 1]

    resampled_tensor = tfio.audio.resample(tensor, input_rate, output_rate)

    if was_rank_2:
        resampled_tensor = tf.squeeze(resampled_tensor, axis=-1)  # e.g. back to [1, 300]

    return resampled_tensor

In [None]:
def get_vst_inputs_for_batch(source_batch, buffer_size_in_samples, n_frames_in_buffer):
    """Yield VST-style per-buffer inputs cut from the first example of `source_batch`.

    The frame-rate features ('loudness_db', 'f0_hz', 'midi', 'onsets', 'offsets')
    are resampled from the dataset frame rate to the "user" frame rate implied by
    `buffer_size_in_samples` and `n_frames_in_buffer`, then sliced buffer by buffer.

    Args:
        source_batch: dict with the keys 'audio', 'loudness_db', 'f0_hz', 'midi',
            'onsets', 'offsets' and 'instrument_id'; only batch element 0 is used.
        buffer_size_in_samples: number of audio samples per yielded buffer.
        n_frames_in_buffer: number of feature frames per yielded buffer.

    Yields:
        One dict per buffer with the same keys as listed above: 'audio' holds
        `buffer_size_in_samples` samples, 'loudness_db' and 'f0_hz' are reshaped
        to `[n_frames_in_buffer]`, and 'instrument_id' is reshaped to `[1]`.
    """
    n_audio_samples = source_batch['audio'].shape[1]
    n_source_frames = source_batch['f0_hz'].shape[1]

    dataset_frame_size = n_audio_samples // n_source_frames
    n_dataset_frames_in_batch = n_audio_samples // dataset_frame_size
    n_dataset_frames_in_buffer = buffer_size_in_samples // dataset_frame_size
    n_dataset_frames_in_one_user_frame = n_dataset_frames_in_buffer // n_frames_in_buffer
    n_user_frames_in_batch = n_source_frames // n_dataset_frames_in_one_user_frame

    # The resampled features are identical for every buffer, so compute them once
    # instead of once per yielded buffer.
    resized = {
        key: resample_tensor(source_batch[key], n_dataset_frames_in_batch, n_user_frames_in_batch)
        for key in ('loudness_db', 'f0_hz', 'midi', 'onsets', 'offsets')
    }
    audio = source_batch['audio'][0]
    instrument_id = tf.reshape(source_batch['instrument_id'][0], [1])

    n_buffers = n_audio_samples // buffer_size_in_samples

    for buffer_idx in range(n_buffers):
        sample_start = buffer_idx * buffer_size_in_samples

        # The resized features are indexed in *user* frames, so a buffer covers
        # exactly n_frames_in_buffer consecutive entries. Using dataset-frame
        # offsets/lengths here only lines up when the batch is a single buffer.
        frame_start = buffer_idx * n_frames_in_buffer
        frame_end = frame_start + n_frames_in_buffer

        yield {
            'audio': audio[sample_start:sample_start + buffer_size_in_samples],
            'loudness_db': tf.reshape(resized['loudness_db'][0][frame_start:frame_end], [n_frames_in_buffer]),
            'f0_hz': tf.reshape(resized['f0_hz'][0][frame_start:frame_end], [n_frames_in_buffer]),
            'midi': resized['midi'][0][frame_start:frame_end],
            'onsets': resized['onsets'][0][frame_start:frame_end],
            'offsets': resized['offsets'][0][frame_start:frame_end],
            'instrument_id': instrument_id,
        }
    

In [None]:
buffer_size_in_samples = 64000
n_frames_in_buffer = 500
frame_size = buffer_size_in_samples // n_frames_in_buffer

print(frame_size)

In [None]:
# Run the VST wrapper buffer by buffer over one evaluation batch, synthesize audio
# from the predicted controls and concatenate the buffers into one signal.
# NOTE(review): this cell uses `eval_sample_batch`, `get_process_group` and `play`,
# which are defined/imported in other cells (some of them further down), and a global
# `model_vst` whose only visible definitions are a local inside export_to_tflite and
# commented-out lines -- it must be created before this cell can run.
accum_sound = None

for vst_input_buffer in get_vst_inputs_for_batch(source_batch=eval_sample_batch,
                                 buffer_size_in_samples=buffer_size_in_samples,
                                 n_frames_in_buffer=n_frames_in_buffer):
    
    # A fresh random 512-element state is passed for every buffer.
    res = model_vst(**vst_input_buffer, state=tf.random.uniform([512]))
    print('finished running the model')
        
    # Reshape the controls to [batch=1, frames, channels]; f0 is taken from the
    # buffer inputs rather than from the model output.
    res['f0_hz']                 = tf.reshape(vst_input_buffer['f0_hz'],    [1, n_frames_in_buffer, 1])
    res['amplitudes']            = tf.reshape(res['amplitudes'],            [1, n_frames_in_buffer, 1])
    res['harmonic_distribution'] = tf.reshape(res['harmonic_distribution'], [1, n_frames_in_buffer, 60])
    res['noise_magnitudes']      = tf.reshape(res['noise_magnitudes'],      [1, n_frames_in_buffer, 65])

    # NOTE(review): the processor group is rebuilt on every iteration with identical arguments.
    my_processor_group_vst = get_process_group(n_frames=n_frames_in_buffer, frame_size=frame_size, sample_rate=16000, use_angular_cumsum=False)
    
    print(f'res.keys={res.keys()}')
    
    my_control_params = my_processor_group_vst.get_controls(res, verbose=False)
    my_synth_audio = my_processor_group_vst.get_signal(my_control_params)
    
    # Append this buffer's audio along axis 1.
    if accum_sound is None:
        accum_sound = my_synth_audio
    else:
        accum_sound = tf.concat([accum_sound, my_synth_audio], axis=1)
    
print("finally..")
play(accum_sound)

In [None]:
def append_postfix_to_filename(file_path, postfix):
    """Insert `postfix` between the base name and the extension of `file_path`.

    The directory part and the extension are kept as they are, e.g.
    `logs/e1_s2.tflite` with postfix `_midi` becomes `logs/e1_s2_midi.tflite`.
    """
    directory, name = os.path.split(file_path)
    stem, extension = os.path.splitext(name)
    return os.path.join(directory, stem + postfix + extension)

In [None]:
def export_to_tflite(ae_model, path):
    """Export `ae_model` as two TFLite files derived from `path`.

    The model is wrapped in `MIDIExpressionAE_VST_IO_Wrapper`, built once with
    random inputs, and converted twice: with `run_synth_coder_only=False` into
    `<path>` with `_midi` inserted before the extension, and with
    `run_synth_coder_only=True` into the `_synth` variant.

    The flags `run_without_synths`, `run_inside_vst` and `run_synth_coder_only`
    of `ae_model` are changed for the export and put back to their previous
    values afterwards, also when building or converting raises.

    Args:
        ae_model: the model to export.
        path: target file path; the `_midi` / `_synth` postfixes are added to it.
    """
    orig_run_without_synths = ae_model.run_without_synths
    orig_run_inside_vst = ae_model.run_inside_vst
    orig_run_synth_coder_only = ae_model.run_synth_coder_only

    def get_fake_vst_inputs(buffer_size_in_samples, n_frames_in_buffer):
        """Return random tensors shaped like one buffer of wrapper inputs."""
        audio           = tf.random.uniform([buffer_size_in_samples])
        loudness_db     = tf.random.uniform([n_frames_in_buffer])
        f0_hz           = tf.random.uniform([n_frames_in_buffer])
        midi            = tf.random.uniform([n_frames_in_buffer])
        onsets          = tf.random.uniform([n_frames_in_buffer])
        offsets         = tf.random.uniform([n_frames_in_buffer])
        instrument_id   = tf.random.uniform([1])

        return audio, loudness_db, f0_hz, midi, onsets, offsets, instrument_id

    # Realtime signature: a 1024-sample audio buffer, six single-value inputs and
    # a 512-element state. (An offline variant would use a 64000-sample buffer
    # with 1000-frame features.)
    input_signature_realtime = (
        tf.TensorSpec(shape=[1024], dtype=tf.float32),
        tf.TensorSpec(shape=[1], dtype=tf.float32),
        tf.TensorSpec(shape=[1], dtype=tf.float32),
        tf.TensorSpec(shape=[1], dtype=tf.float32),
        tf.TensorSpec(shape=[1], dtype=tf.float32),
        tf.TensorSpec(shape=[1], dtype=tf.float32),
        tf.TensorSpec(shape=[1], dtype=tf.float32),
        tf.TensorSpec(shape=[512], dtype=tf.float32),
    )

    def convert_and_save(target_model, save_path):
        """Trace `target_model.call` with the realtime signature and write the TFLite bytes."""
        call_fn = tf.function(target_model.call, input_signature=input_signature_realtime)
        concrete_func = call_fn.get_concrete_function(*input_signature_realtime)

        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])

        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.target_spec.supported_ops = [tf.lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8, tf.lite.OpsSet.TFLITE_BUILTINS,tf.lite.OpsSet.SELECT_TF_OPS ]

        print('starting conversion..')
        tflite_model = converter.convert()
        print('conversion finished')

        import pathlib

        tflite_model_file = pathlib.Path(save_path)
        tflite_model_file.write_bytes(tflite_model)
        print(f'tflite model saved to {save_path}')

    try:
        ae_model.run_without_synths = True
        ae_model.run_inside_vst = True
        ae_model.run_synth_coder_only = False

        # Wrap the model that was passed in, not whatever the notebook-global
        # `model` happens to be.
        model_vst = MIDIExpressionAE_VST_IO_Wrapper(ae_model=ae_model, vst_buffer_size=1024, vst_frame_size=1024)

        # Call the wrapper once with random inputs so that it is built.
        vst_inputs = get_fake_vst_inputs(buffer_size_in_samples=1024,
                                        n_frames_in_buffer=1)
        model_vst(*vst_inputs, state=tf.random.uniform([512]))
        print('model_vst is built')

        tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG)

        model_vst.ae.run_synth_coder_only=False
        convert_and_save(model_vst, save_path=append_postfix_to_filename(path, '_midi'))

        model_vst.ae.run_synth_coder_only=True
        convert_and_save(model_vst, save_path=append_postfix_to_filename(path, '_synth'))
    finally:
        # Always restore the caller's flags: the training loop catches export
        # errors and keeps training, so the model must not stay in export mode.
        ae_model.run_without_synths = orig_run_without_synths
        ae_model.run_inside_vst = orig_run_inside_vst
        ae_model.run_synth_coder_only = orig_run_synth_coder_only
    

In [None]:
def picke_train_history(epoch, step):
    """Pickle the global loss histories into `log_dir`, tagged with epoch and step.

    `losses_history` is written to `losses_history_e<epoch>_s<step>` and
    `eval_losses_history` to `eval_losses_history_e<epoch>_s<step>`.
    """
    import pickle

    # (destination, object) pairs, written in this order.
    snapshots = (
        (f'{log_dir}/losses_history_e{epoch}_s{step}', losses_history),
        (f'{log_dir}/eval_losses_history_e{epoch}_s{step}', eval_losses_history),
    )

    for destination, history in snapshots:
        with open(destination, "wb") as fp:
            pickle.dump(history, fp)

In [None]:
# Based on the original Magenta's train_synthesis_generator.py'

def train(training_data, training_data_length, training_epochs, start_epoch=1):
    """Training loop including evaluation.

    Iterates over `training_data` once per epoch, applies clipped gradients to
    the model (and to the discriminator when GAN training is active), and per
    epoch optionally logs, evaluates, saves audio samples, saves a checkpoint
    plus a TFLite export, and pickles the loss histories.

    Relies on notebook globals created in other cells: `model`, `hp`,
    `optimizer`, `optimizer_disc`, `loss_helper`, `gan_loss_helper`, `net_D`,
    `writer`, `log_dir`, `plotlosses`, `losses_history`, `eval_losses_history`,
    `evaluation_data`, `train_sample_batch` and `eval_sample_batch`.

    Args:
        training_data: iterable of training batches.
        training_data_length: number of batches per epoch (used for step numbering).
        training_epochs: number of epochs to train for.
        start_epoch: number of the first epoch to run.
    """
    start_time = time.time()
    loss_helper.reset_metrics()

    steps_in_epoch = int(training_data_length) # do not need to divide by hp.batch_size because each item is a batch itself
    logging.info(f"Starting training: epochs={training_epochs}, steps_in_epoch={steps_in_epoch}, start_epoch={start_epoch}")
    step = 0

    # NOTE(review): this range yields training_epochs + 1 epochs
    # (start_epoch .. start_epoch + training_epochs inclusive) -- confirm intended.
    for epoch in range(start_epoch, training_epochs + start_epoch + 1):
        for i, data in enumerate(training_data):
            # Global 1-based step number derived from the epoch number.
            step = ((epoch-1) * steps_in_epoch + i) + 1
            print(f"training on epoch={epoch}, step={step}")

            # Run the model and get the loss.
            with tf.GradientTape() as tape, tf.GradientTape() as disc_tape:
                outputs = model(data, training=True, run_synth_coder_only=hp.run_synth_coder_only)

                loss_dict_recon = loss_helper.compute_loss(data, outputs,
                                                           synth_coder_only=hp.run_synth_coder_only,
                                                           add_synth_loss=hp.add_synth_loss)

                if not hp.run_synth_coder_only and hp.use_gan:
                    cond, real_outputs, fake_outputs = gan_loss_helper.get_disc_input(outputs)
                    D_fake = net_D([cond, fake_outputs])
                    D_real = net_D([cond, real_outputs])
                    loss_dict_disc = gan_loss_helper.compute_disc_loss(D_fake, D_real)
                    loss, loss_dict_gen = gan_loss_helper.compute_gen_loss(D_fake, D_real, loss_dict_recon['total_loss'])
                else:
                    loss = loss_dict_recon['total_loss']

            # Clip and apply gradients.
            grads = tape.gradient(loss, model.trainable_variables)
            grads, _ = tf.clip_by_global_norm(grads, hp.clip_grad)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            loss_helper.update_metrics(loss_dict_recon)
            loss_helper.write_summary(loss_dict_recon, writer, 'Train', step)

            losses_history.append(loss_dict_recon)
            plotlosses.update({'train_loss': loss_dict_recon['total_loss']})

            # Train discriminator and update GAN loss.
            if not hp.run_synth_coder_only and hp.use_gan:
                gradients_of_discriminator = disc_tape.gradient(loss_dict_disc['disc_loss'], net_D.trainable_variables)
                optimizer_disc.apply_gradients(zip(gradients_of_discriminator, net_D.trainable_variables))
                gan_loss_helper.update_metrics(loss_dict_disc)
                gan_loss_helper.write_summary(loss_dict_disc, writer, 'Train', step)
                gan_loss_helper.update_metrics(loss_dict_gen)
                gan_loss_helper.write_summary(loss_dict_gen, writer, 'Train', step)

            plotlosses.send()

        # Print logging summary.
        if epoch % hp.log_interval == 0:
            elapsed = time.time() - start_time
            current_lr = optimizer._decayed_lr('float32').numpy()
            msg = f'| {epoch:6d} epoch | lr {current_lr:02.2e} ' \
                  f'| ms/batch {(elapsed * 1000 / hp.log_interval):5.2f} '
            msg = msg + loss_helper.get_loss_log()
            loss_helper.reset_metrics()
            if not hp.run_synth_coder_only:
                msg = msg + gan_loss_helper.get_loss_log()
                gan_loss_helper.reset_metrics()
            logging.info(msg)

            # Print a message on the same line
            IPython.display.clear_output()
            print(msg)

            start_time = time.time()

        # Evaluate.
        if epoch % hp.eval_interval == 0:
            eval_start_time = time.time()

            eval_loss_dict = evaluate(evaluation_data, epoch=epoch, step=step)

            plotlosses.update({'eval_loss_total': eval_loss_dict['total_loss'],
                              'eval_loss_midi': eval_loss_dict['loss_spectral_midi']})
            eval_losses_history.append(eval_loss_dict)

            # Synthesize training data.
            outputs = model(train_sample_batch, training=True, run_synth_coder_only=hp.run_synth_coder_only)

            save_results(outputs['synth_audio'], train_sample_batch['audio'], log_dir, f'train_{epoch}_synth', hp.sample_rate)
            if 'midi_audio' in outputs.keys():
                save_results(outputs['midi_audio'], train_sample_batch['audio'], log_dir, f'train_{epoch}_midi', hp.sample_rate)
            if hp.write_tfrecord_audio:
                write_tensorboard_audio(writer, train_sample_batch, outputs, epoch, tag='Train')

            # Synthesize evaluation data.
            outputs = model(eval_sample_batch, training=False, run_synth_coder_only=hp.run_synth_coder_only)
            save_results(outputs['synth_audio'], eval_sample_batch['audio'], log_dir, f'eval_{epoch}_synth', hp.sample_rate)
            if 'midi_audio' in outputs.keys():
                save_results(outputs['midi_audio'], eval_sample_batch['audio'], log_dir, f'eval_{epoch}_midi', hp.sample_rate)
            if hp.write_tfrecord_audio:
                write_tensorboard_audio(writer, eval_sample_batch, outputs, epoch, tag='Eval')

            eval_time_elapsed = time.time() - eval_start_time
            logging.info(f"Evaluation took {eval_time_elapsed*1000}ms")

        # DDSP Inference training finished.
        # Start training Synthesis Generator and dump dataset for expression generator.
        # The threshold counts epochs run in this call, not the absolute epoch number.
        if (epoch - start_epoch + 1) >= hp.synth_coder_training_epochs:
            hp.run_synth_coder_only = False
            if not hp.add_synth_loss:
                model.freeze_synth_coder()

        # Save weights for the whole model.
        if epoch % hp.checkpoint_save_interval == 0:
            model.save_weights(f'{log_dir}/e{epoch}_s{step}')
            try:
                export_to_tflite(ae_model=model, path=f'{log_dir}/e{epoch}_s{step}.tflite')
            except Exception:
                # A failed export must not stop training. logging.exception logs
                # the message at ERROR level together with the current traceback,
                # so no separate `traceback` import is needed.
                logging.exception("An exception occurred while exporting to tflite")
                print("An exception occurred while exporting to tflite")

        picke_train_history(epoch=epoch, step=step)
        plotlosses.send()


def evaluate(evaluation_data, epoch, step):
    """Evaluate the global `model` on every batch of `evaluation_data`.

    Losses are accumulated in a fresh `ReconLossHelper`, their mean is written
    to the global summary `writer` under the 'Eval' tag at `step`, and a
    one-line summary is logged.

    Returns:
        The loss dict reported by the helper's `get_loss_dict()`.
    """
    eval_loss_helper = ReconLossHelper(hp, eval_recon_loss=True)
    start_time = time.time()

    for batch in evaluation_data:
        predictions = model(batch, training=False, run_synth_coder_only=hp.run_synth_coder_only)
        batch_losses = eval_loss_helper.compute_loss(batch, predictions,synth_coder_only=hp.run_synth_coder_only)
        eval_loss_helper.update_metrics(batch_losses)

    eval_loss_helper.write_mean_summary(writer, 'Eval', step)

    # Timing header followed by the helper's formatted loss values.
    logging.info(
        f'eval: | epoch {epoch:6d} | eval time: {(time.time() - start_time):3.3f}'
        + eval_loss_helper.get_loss_log()
    )

    return eval_loss_helper.get_loss_dict()


In [None]:
hp.data_dir = '../data/'
hp.multi_instrument = True
hp.instrument = 'vn'

#hp.data_dir = '/scratch/ssd004/scratch/burakovr/midi_ddsp_urmp_dataset/all_instruments'
#hp.multi_instrument = True
#hp.instrument = 'all'

hp.vst_inference_mode = False
hp.train_synth_coder_first = True
hp.training_epochs =96 # 5k steps
hp.log_interval = 1
hp.checkpoint_save_interval = 1
hp.eval_interval = 1
hp.synth_coder_training_epochs = 32
#hp.synth_coder_training_epochs = 15

hp.batch_size=16
#hp.reverb_length = 16000
hp.reverb_length = 48000

#experiment_name = "full_all_train_and_export"
experiment_name = "full_vn_train_and_export_batchsize_16_3"

In [None]:
# From original Magenta's train_synthesis_generator.py'

# Load model, create log directory and log file.
log_dir = f"logs/logs_synthesis_generator/{experiment_name}"
print(log_dir)

In [None]:
# From original Magenta's train_synthesis_generator.py'

tf.get_logger().setLevel('INFO')

In [None]:
# From original Magenta's train_synthesis_generator.py'

# TensorBoard summary writer for this run's log directory.
writer = tf.summary.create_file_writer(log_dir)
log_path = os.path.join(log_dir, 'train2.log')
# Send log records both to the log file and to the notebook's stdout.
# NOTE(review): logging.basicConfig does nothing when the root logger already has
# handlers -- confirm these handlers really get installed in this kernel.
logging.basicConfig(level=logging.INFO,
                  format='%(asctime)s - %(levelname)s: %(message)s',
                  handlers=[
                    logging.FileHandler(log_path),
                    logging.StreamHandler(sys.stdout)]
                  )

In [None]:
latest_epoch_path = tf.train.latest_checkpoint(f'{log_dir}/')
latest_epoch_weights_name = os.path.basename(latest_epoch_path) if latest_epoch_path else None

hp.restore_path=latest_epoch_path
print(f'hp.restore_path={hp.restore_path}')
print(f'latest_epoch_weights_name={latest_epoch_weights_name}')

In [None]:
# Load dataset.
training_data, length_training_data, evaluation_data, length_evaluation_data = get_dataset(hp, training_data_repeats=1)

# Filter the data so that we don't use values we don't need

# do not keep 'f0_confidence', 'note_active_frame_indices', 'note_active_velocities', 'note_offsets', 'note_onsets', 'power_db',  'recording_id',
keys_to_keep = ['audio',
                'f0_hz',
                'instrument_id',
                'loudness_db',
                'midi',
                'onsets',
                'offsets']
# TODO: What to do with 'mel'? Is it used?

def filter_keys_in_dataset(example):
    """Return a new dict holding only the entries of `example` named in `keys_to_keep`."""
    kept = {}
    for key in keys_to_keep:
        kept[key] = example[key]
    return kept

training_data = training_data.map(filter_keys_in_dataset)
evaluation_data = evaluation_data.map(filter_keys_in_dataset)

# For optional debugging purposes
training_data = training_data#.take(10)
evaluation_data = evaluation_data#.take(10)

eval_sample_batch = next(iter(evaluation_data))
train_sample_batch = next(iter(training_data))
logging.info('Data loaded! Data size: %s', str(length_training_data))

In [None]:
from midi_ddsp.modules.model.model_vst import MIDIExpressionAE_VST_IO_Wrapper

#hp.use_mel = False
hp.use_mel = True

# Create Synthesis Generator
model = get_synthesis_generator(hp)

In [None]:
logging.info(f"model.run_synth_coder_only={model.run_synth_coder_only},\n"
      f"model.run_without_synths={model.run_without_synths},\n"
      f"model.run_inside_vst={model.run_inside_vst}")

In [None]:
model._build(get_fake_data_synthesis_generator(hp))

In [None]:
print(tf.keras.Model.summary(model, expand_nested=True))

In [None]:
if hp.restore_path:
    print(f'restoring from {hp.restore_path}')
    model.load_weights(hp.restore_path)
    log_dir = os.path.dirname(hp.restore_path)

In [None]:
oo = model(train_sample_batch)
play(oo['synth_audio'])
play(oo['midi_audio'])

In [None]:
# From original Magenta's train_synthesis_generator.py'

# Create optimizer, loss helper and discriminator.
scheduler = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate=hp.lr, decay_steps=1000, decay_rate=0.99)
optimizer = tf.keras.optimizers.legacy.Adam(learning_rate=scheduler)

loss_helper = ReconLossHelper(hp)
gan_loss_helper = GANLossHelper(lambda_recon=hp.lambda_recon, lambda_G=hp.lambda_G, sg_z=hp.sg_z)
optimizer_disc = tf.keras.optimizers.Adam(learning_rate=hp.lr_disc)

net_D = Discriminator(nhid=hp.discriminator_dim)
# 64=instrument_emb_dim
z_dim = hp.discriminator_dim + int(hp.multi_instrument) * 64
# synth_params_dim = dim(nharmonic + nnoise + amplitude + f0)
synth_params_dim = hp.nhramonic + hp.nnoise + 2
_ = net_D((tf.random.normal([4, 1000, z_dim]), tf.random.normal([4, 1000, synth_params_dim])))


In [None]:
print(log_path)

In [None]:
from midi_ddsp.utils.inference_utils import get_process_group

my_processor_group = get_process_group(model.n_frames, model.frame_size, model.sample_rate, use_angular_cumsum=False)

In [None]:
from tensorflow.python.client import device_lib

device_lib.list_local_devices()

In [None]:
# Print model summary and hyperparameters.
model.summary(print_fn=logging.info)
logging.info(str(print_hparams(hp)))

In [None]:
#model_vst = MIDIExpressionAE_VST_IO_Wrapper(ae_model=model, vst_buffer_size=1024, vst_frame_size=1024)
#model_vst = MIDIExpressionAE_VST_IO_Wrapper(ae_model=model, vst_buffer_size=64000, vst_frame_size=64)

In [None]:
# Restore losses so that we could see them

import pickle

plotlosses = PlotLosses()

print(latest_epoch_weights_name)

if latest_epoch_weights_name is None:
    # Fresh run without a checkpoint: there are no pickled histories to load, so
    # keep the histories as they are instead of failing on a missing file.
    print('No checkpoint found; nothing to restore into the loss plots')
else:
    # The histories are pickles written by picke_train_history under the same
    # `e<epoch>_s<step>` name as the checkpoint. Only load files produced by this
    # notebook: unpickling untrusted data can execute arbitrary code.
    with open(f'{log_dir}/losses_history_{latest_epoch_weights_name}', 'rb') as file:
        losses_history = pickle.load(file)

    with open(f'{log_dir}/eval_losses_history_{latest_epoch_weights_name}', 'rb') as file:
        eval_losses_history = pickle.load(file)

    # One eval entry is stored per evaluated epoch, so this ratio approximates the
    # number of training steps per epoch. Both max() guards keep the division and
    # the modulo below defined when a history is empty or very short.
    approx_steps_in_epoch = max(len(losses_history) // max(len(eval_losses_history), 1), 1)
    num_epochs = len(eval_losses_history)
    num_steps = len(losses_history)
    print(f'num_epochs={num_epochs}')
    print(f'num_steps={num_steps}')
    print(f'approx_steps_in_epoch={approx_steps_in_epoch}')

    # Replay the history into the live plot: one train point per step, plus the
    # eval losses of the epoch that just finished at every epoch boundary.
    for i in range(len(losses_history)):
        if i != 0 and i % approx_steps_in_epoch == 0:
            epoch = i // approx_steps_in_epoch - 1
            print(i, epoch)
            plotlosses.update({'eval_loss_total': eval_losses_history[epoch]['total_loss'],
                              'eval_loss_midi': eval_losses_history[epoch]['loss_spectral_midi']})

        plotlosses.update({'train_loss': losses_history[i]['total_loss']})

    plotlosses.send()

In [None]:
# Start training loop

# Derive the epoch to start from. Checkpoints are saved as `e<epoch>_s<step>`, so the
# epoch number is the part before the first '_' with its leading 'e' removed.
if hp.restore_path:
    bname = os.path.basename(hp.restore_path)
    parts = bname.split(sep='_')
    start_epoch = int(parts[0][1:])
    
    # Once the restored epoch is past the synth-coder-only phase, skip that phase.
    if hp.train_synth_coder_first is True and start_epoch > hp.synth_coder_training_epochs:
        hp.train_synth_coder_first = False
else:
    start_epoch = 1

#start_epoch = 16
#hp.train_synth_coder_first = False
    
print(start_epoch)
print(hp.train_synth_coder_first)

# NOTE(review): when resuming, start_epoch is the epoch number stored in the checkpoint
# name, and train() starts its loop at start_epoch, so that epoch number is run again --
# confirm whether start_epoch + 1 was intended.
if hp.mode == 'train':
    if hp.train_synth_coder_first:
      hp.run_synth_coder_only = True
      model.train_synth_coder_only()
    else:
      hp.run_synth_coder_only = False
      model.freeze_synth_coder()

    train(training_data=training_data, training_data_length=length_training_data, training_epochs=hp.training_epochs, start_epoch=start_epoch)

elif hp.mode == 'eval':
    hp.run_synth_coder_only = False
    # NOTE(review): the step passed here is start_epoch divided by the number of
    # training batches -- presumably a multiplication was intended; verify.
    evaluate(evaluation_data, epoch=start_epoch, step=int(start_epoch/length_training_data))

In [None]:
import matplotlib.pyplot as plt

total_losses = [item['total_loss'] for item in losses_history]

plt.figure(figsize=(40,10))
plt.locator_params(nbins=30)

ax = plt.gca()
ax.set_ylim([0, 20])

plt.plot(total_losses, 'g', label='Training loss')
plt.title('Training loss')
plt.xlabel('Steps')
plt.ylabel('Total Loss')
plt.legend()
plt.show()
ax

#plt.savefig(f'{log_dir}/loss.png')

In [None]:
import matplotlib.pyplot as plt

# Plot the total validation loss per evaluated epoch and store the figure as a PNG.
total_losses = [item['total_loss'] for item in eval_losses_history]
#total_losses = [1, 2, 5]

plt.figure(figsize=(40,10))
plt.locator_params(nbins=30)

ax = plt.gca()
ax.set_ylim([0, 20])

plt.plot(total_losses, 'g', label='Validation loss')
plt.title('Validation loss')
plt.xlabel('Epochs')
plt.ylabel('Total Loss')
plt.legend()

# Save before showing: once plt.show() has displayed the figure, a following
# plt.savefig() typically works on a new, empty figure and writes a blank image.
plt.savefig(f'{log_dir}/val_loss.png')
plt.show()

In [None]:
import pickle

train_history_path=f'{log_dir}/losses_history'
eval_train_history_path=f'{log_dir}/eval_losses_history'

with open(train_history_path, "wb") as fp:
    pickle.dump(losses_history, fp)

with open(eval_train_history_path, "wb") as fp:
    pickle.dump(eval_losses_history, fp)

In [None]:
input_audios = next(iter(training_data))

outputs = model(input_audios)

In [None]:
for i in range(input_audios['audio'].shape[0]):
    print("Original audio: ")
    play(input_audios['audio'][i])
    specplot(input_audios['audio'][i])

    print("Reconstructed audio: ")
    play(outputs['synth_audio'][i])
    specplot(outputs['synth_audio'][i])