# Imports

In [None]:
# Make our custom DDSP libraries (and, off Windows, the local sparsenet
# install) importable by collecting the directories that must go on sys.path.

import os
import platform
import sys

# The MIDI-DDSP checkout two levels up is the same on every platform.
midi_ddsp_module_path = os.path.abspath('../../')

if platform.system() != 'Windows':
    ddsp_module_path = os.path.abspath('../../../ddsp/ddsp-playground-2/')
    # The "original" checkouts are only configured on the Windows machine.
    original_midi_ddsp_module_path = None
    original_ddsp_module_path = None
else:
    ddsp_module_path = os.path.abspath('../../../ddsp-playground-2/')
    # Presumably unmodified upstream checkouts used for comparison.
    original_midi_ddsp_module_path = 'E:/Code/Projects/TimbreTransfer/original-midi-ddsp/'
    original_ddsp_module_path = 'E:/Code/Projects/TimbreTransfer/original-ddsp-for-vst-debugging/'

def apply_module_path(module_path):
    """Append ``module_path`` to ``sys.path`` unless it is already present.

    Prints what it did so the notebook output shows which source trees are
    importable. ``None`` is accepted and ignored; the notebook uses it for
    paths that are not configured on the current platform.

    Args:
        module_path: directory to make importable, or ``None`` to skip.
    """
    print(f"module_path={module_path}")
    if module_path is None:
        # Nothing configured for this platform; never put None on sys.path.
        return
    if module_path not in sys.path:
        sys.path.append(module_path)
        print(f"appending {module_path} to sys.path")
    else:
        print(f"do not appending {module_path} to sys.path")

# Register every configured source tree. The "original" checkouts are None
# when not configured for this platform, so skip those instead of putting
# None on sys.path.
for _module_path in (midi_ddsp_module_path,
                     ddsp_module_path,
                     original_midi_ddsp_module_path,
                     original_ddsp_module_path):
    if _module_path is not None:
        apply_module_path(_module_path)

# On the (non-Windows) cluster, sparsenet lives in this virtualenv's site-packages.
if platform.system() != 'Windows':
    sparsenet_module_path_abs = '/ssd003/home/burakovr/projects/vova/envs/main/lib/python3.8/site-packages/'
    apply_module_path(sparsenet_module_path_abs)

import midi_ddsp

In [None]:
#  Copyright 2022 The MIDI-DDSP Authors.
#  #
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#  #
#      http://www.apache.org/licenses/LICENSE-2.0
#  #
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

"""Training code for Synthesis Generator."""

import tensorflow as tf
import time
import os
import sys
import logging
import argparse
import IPython

from keras.utils.layer_utils import print_summary
from livelossplot import PlotLosses

import original_midi_ddsp

from midi_ddsp.data_handling.get_dataset import get_dataset
from midi_ddsp.utils.training_utils import print_hparams, set_seed, \
    save_results, str2bool
from midi_ddsp.utils.summary_utils import write_tensorboard_audio
#                          from midi_ddsp.hparams_synthesis_generator import hparams as hp
from midi_ddsp.hparams_synthesis_generator import hparams_debug as hp
from midi_ddsp.modules.recon_loss import ReconLossHelper
from midi_ddsp.modules.gan_loss import GANLossHelper
from midi_ddsp.modules.get_synthesis_generator import get_synthesis_generator, \
    get_fake_data_synthesis_generator
from midi_ddsp.modules.discriminator import Discriminator

from ddsp.colab.notebook_utils import play, specplot

# Dataset

### Vn dataset

In [None]:
# Configure the shared `hp` for the single-instrument violin data.
# NOTE: `hp` is a shared mutable object that later cells keep modifying, so
# these cells must be run in order.
hp.data_dir = '../../data/'
hp.batch_size = 1 # for saving to TFRecord purposes

hp.multi_instrument = False
hp.instrument = 'vn'

In [None]:
# Load vn dataset.
# get_dataset receives the shared `hp`; presumably it reads the data_dir /
# instrument / batch_size set in the previous cell (confirm in get_dataset).
vn_training_data, vn_length_training_data, vn_evaluation_data, vn_length_evaluation_data = get_dataset(hp, training_data_repeats=1)

In [None]:
# Keep the first eval and training batches; they are reused later as example
# model inputs.
vn_eval_sample_batch = next(iter(vn_evaluation_data))
vn_train_sample_batch = next(iter(vn_training_data))
# NOTE(review): logged at INFO level, which may not be shown unless logging is configured.
logging.info('Violin data loaded! Data size: %s', str(vn_length_training_data))

In [None]:
# One violin training batch, played here and used below as a quick model check.
vn_training_example = next(iter(vn_training_data))
play(vn_training_example['audio'])

In [None]:
# Training-set size reported by get_dataset for the violin data.
print(vn_length_training_data)

### Cl dataset

In [None]:
# Switch the shared `hp` to the clarinet data; multi_instrument and batch_size
# keep the values set in the violin cell.
hp.data_dir = '../../data_cl/'
hp.instrument = 'cl'

In [None]:
# Load cl (clarinet) dataset using the hp settings from the previous cell.
cl_training_data, cl_length_training_data, cl_evaluation_data, cl_length_evaluation_data = get_dataset(hp, training_data_repeats=1)

In [None]:
# Keep the first clarinet eval and training batches as example inputs.
cl_eval_sample_batch = next(iter(cl_evaluation_data))
cl_train_sample_batch = next(iter(cl_training_data))
logging.info('Clarinet data loaded! Data size: %s', str(cl_length_training_data))

In [None]:
# Listen to one clarinet training batch.
cl_training_example = next(iter(cl_training_data))
play(cl_training_example['audio'])

### All dataset

In [None]:
# Switch the shared `hp` to the multi-instrument ("all") data.
hp.data_dir = '../../data_all/'
hp.instrument = 'all'
hp.multi_instrument = True

In [None]:
# Load the multi-instrument ("all") dataset using the hp settings above.
all_training_data, all_length_training_data, all_evaluation_data, all_length_evaluation_data = get_dataset(hp, training_data_repeats=1)

In [None]:
# Keep the first multi-instrument batches; all_train_sample_batch is used
# below as the example input that builds each model before loading weights.
all_eval_sample_batch = next(iter(all_evaluation_data))
all_train_sample_batch = next(iter(all_training_data))
logging.info('All (multi-instrument) data loaded! Data size: %s', str(all_length_training_data))

In [None]:
# One multi-instrument batch; its feature dtypes later define how the
# serialized examples are decoded (see feature_datatypes).
all_training_example = next(iter(all_training_data))
play(all_training_example['audio'])
print(all_training_example['instrument_id'])

In [None]:
# Evaluation-set size reported by get_dataset for the multi-instrument data.
print(all_length_evaluation_data)

# Create selected eval dataset

In [None]:
def list_selected_eval_dataset(ds, instrument_abb, num_batches=40):
    """Play and collect every example of one instrument from the start of `ds`.

    Scans the first `num_batches` batches of `ds`. Each batch item whose
    instrument matches `instrument_abb` is played, with its position and
    instrument name printed. All collected items are then replayed in order,
    so a subset can be picked by ear.

    Args:
        ds: batched tf.data.Dataset of feature dicts containing at least
            'audio' and 'instrument_id'.
        instrument_abb: instrument abbreviation to keep (e.g. 'vn').
        num_batches: number of batches to scan (default 40).

    Returns:
        list of per-item dicts (one per matching batch item), each feature
        indexed down to that single item.
    """
    name_utils = original_midi_ddsp.data_handling.instrument_name_utils

    selected_examples = []
    for i, ex in enumerate(ds.take(num_batches)):
        # Use the actual batch size rather than assuming 1, so every item of
        # a batched dataset is inspected.
        batch_size = int(ex['instrument_id'].shape[0])
        for j in range(batch_size):
            instrument_id = int(ex['instrument_id'][j])
            if name_utils.INST_ID_TO_ABB_DICT[instrument_id] != instrument_abb:
                continue
            print(f'ex#{i} i#{j} = #{i*batch_size+j}')
            play(ex['audio'][j])
            print(name_utils.INST_ID_TO_NAME_DICT[instrument_id])
            print()

            selected_examples.append({k: v[j] for k, v in ex.items()})

    print('\n\n\n listing gathered samples:')
    for i, ex in enumerate(selected_examples):
        print(f'#{i}')
        play(ex['audio'])
        print()

    return selected_examples
    
# Browse the first batches for violin examples to hand-pick from.
# NOTE(review): the section is about an *eval* selection but this passes the
# training split (all_training_data) - confirm that is intended.
listed_examples = list_selected_eval_dataset(ds=all_training_data, instrument_abb='vn')
    
#selected_audio_examples.append(listed_examples[49])
#print(len(selected_audio_examples))

In [None]:
# Hand-picked per-example dicts (e.g. `selected_audio_examples.append(
# listed_examples[i])`). This must be a list: items are appended, and
# `serialize` turns the rows into columns with pd.DataFrame.from_dict.
selected_audio_examples = []

# Serialize selected eval dataset

In [None]:
def serialize_example(audio, f0_confidence, f0_hz, instrument_id, loudness_db,
                      note_active_frame_indices, note_active_velocities, note_offsets,
                      note_onsets, power_db, recording_id, midi, onsets, offsets):
    """Encode one example's features as a serialized tf.train.Example.

    Every argument is stored whole: it is serialized with
    tf.io.serialize_tensor and kept as a single bytes entry under its
    parameter name. The recording id is printed as a progress trace.

    Returns:
        The serialized proto as a Python ``bytes`` string.
    """
    print(recording_id)

    def as_bytes_feature(tensor):
        payload = tf.io.serialize_tensor(tensor).numpy()
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[payload]))

    # Insertion order matches the parameter order.
    named_values = (
        ('audio', audio),
        ('f0_confidence', f0_confidence),
        ('f0_hz', f0_hz),
        ('instrument_id', instrument_id),
        ('loudness_db', loudness_db),
        ('note_active_frame_indices', note_active_frame_indices),
        ('note_active_velocities', note_active_velocities),
        ('note_offsets', note_offsets),
        ('note_onsets', note_onsets),
        ('power_db', power_db),
        ('recording_id', recording_id),
        ('midi', midi),
        ('onsets', onsets),
        ('offsets', offsets),
    )
    feature_map = {name: as_bytes_feature(value) for name, value in named_values}

    proto = tf.train.Example(features=tf.train.Features(feature=feature_map))
    return proto.SerializeToString()

In [None]:
def tf_serialize_example(sample):
    """Serialize one feature dict inside a tf.data pipeline.

    Wraps the eager-only `serialize_example` in `tf.py_function` so it can be
    used with `Dataset.map`. Values are looked up by key and passed in the
    exact order of `serialize_example`'s parameters. Relying on the order of
    `sample.values()` would silently store features under the wrong names
    whenever the dict's key order differs from that signature.

    Args:
        sample: dict mapping feature name -> tensor for one example. It must
            contain every key in `feature_order` below.

    Returns:
        Scalar tf.string tensor holding the serialized tf.train.Example.
    """
    print({k: v.shape for k, v in sample.items()})

    # Must match the positional parameter order of serialize_example.
    feature_order = (
        'audio', 'f0_confidence', 'f0_hz', 'instrument_id', 'loudness_db',
        'note_active_frame_indices', 'note_active_velocities', 'note_offsets',
        'note_onsets', 'power_db', 'recording_id', 'midi', 'onsets', 'offsets',
    )
    tf_string = tf.py_function(
        serialize_example,
        inp=[sample[key] for key in feature_order],
        Tout=tf.string,
    )
    # py_function output has no static shape; declare it a scalar.
    return tf.reshape(tf_string, ())

In [None]:
def serialize(filename):
    """Write the global `selected_audio_examples` to the TFRecord `filename`.

    The row-wise examples are turned into column lists via a pandas
    round-trip and sliced into a tf.data.Dataset. Each example is printed
    with its index and played for a final check, then written as one
    serialized tf.train.Example record.
    """
    import pandas as pd

    columns = pd.DataFrame.from_dict(selected_audio_examples).to_dict(orient="list")
    examples_ds = tf.data.Dataset.from_tensor_slices(columns)

    for index, example in enumerate(examples_ds):
        print(index)
        play(example['audio'])
        print()

    records_ds = examples_ds.map(tf_serialize_example)

    with tf.io.TFRecordWriter(filename) as writer:
        for record in records_ds:
            writer.write(record.numpy())

In [None]:
# TFRecord holding the hand-picked source examples (relative to the notebook's
# working directory); read back by load_selected_eval_dataset below and
# presumably written earlier with serialize(...).
selected_eval_dataset_filename = "selected_srcs_for_timbre_transfer.tfrecord"

# Load selected eval dataset

In [None]:
# dtype of every feature, taken from a real batch; parse_example uses it as
# out_type when turning the stored bytes back into tensors.
feature_datatypes = {k: v.dtype for k, v in all_training_example.items()}

In [None]:
def parse_example(example_proto):
    """Decode one serialized tf.train.Example produced by `serialize_example`.

    Each feature was stored as one serialized tensor (a bytes scalar). It is
    first read as a scalar string, then turned back into a tensor with the
    dtype recorded in the global `feature_datatypes`.

    Returns:
        dict mapping feature name -> restored tensor.
    """
    feature_names = (
        'audio', 'f0_confidence', 'f0_hz', 'instrument_id', 'loudness_db',
        'note_active_frame_indices', 'note_active_velocities', 'note_offsets',
        'note_onsets', 'power_db', 'recording_id', 'midi', 'onsets', 'offsets',
    )
    bytes_scalar = tf.io.FixedLenFeature([], dtype=tf.string)
    raw_features = tf.io.parse_single_example(
        example_proto, {name: bytes_scalar for name in feature_names})

    return {
        name: tf.io.parse_tensor(serialized, out_type=feature_datatypes[name])
        for name, serialized in raw_features.items()
    }

In [None]:
def load_selected_eval_dataset(path):
    """Read the TFRecord at `path` back into a dataset of feature dicts.

    Every restored example is printed with its index and played, so the
    selection can be checked by ear. The returned dataset is unbatched.
    """
    restored_ds = tf.data.TFRecordDataset(filenames=[path])
    restored_ds = restored_ds.map(parse_example)

    for index, example in enumerate(restored_ds):
        print(index)
        play(example['audio'])
        print()

    return restored_ds

In [None]:
# Batch the restored examples. Later cells only take the first batch
# (next(iter(...))), so 17 is presumably the number of hand-picked examples
# (everything in one batch) - TODO confirm; any extra examples would be ignored.
selected_eval_dataset = load_selected_eval_dataset(selected_eval_dataset_filename).batch(17)

# Models loading

In [None]:
def get_midi_ddsp_model(hp, path, example_input, use_original_model):
    """Build a MIDI-DDSP synthesis generator, load its weights and return it.

    Args:
        hp: hparams used to construct the model.
        path: checkpoint path handed to `load_weights`.
        example_input: one input batch. The model is called on it once before
            the weights are loaded, which builds its variables.
        use_original_model: construct the model with the upstream
            `original_midi_ddsp` package instead of the local `midi_ddsp`.

    Returns:
        (synth_coder, model): the model's `synth_coder` sub-model and the full
        model.
    """
    build_model = (
        original_midi_ddsp.modules.get_synthesis_generator.get_synthesis_generator
        if use_original_model
        else get_synthesis_generator
    )
    model = build_model(hp)

    model(example_input)
    model.summary()

    model.load_weights(path)
    synth_coder = model.synth_coder

    for keras_model in (model, synth_coder):
        tf.keras.Model.summary(keras_model, expand_nested=False)

    return synth_coder, model

In [None]:
# Root folder of the trained checkpoints loaded below (a local Windows path).
midi_ddsp_models_dir = 'E:/Code/TimbreTransfer_ExperimentExamples/MIDI_DDSP/models/'

### Magenta all

In [None]:
# Magenta's multi-instrument checkpoint (judging by its folder name), built
# with the upstream package. NOTE: mutates the shared `hp`; later model cells
# inherit anything they do not reset, so run cells in order.
hp.multi_instrument=True
hp.instrument='vn'

hp.reverb_length = 48000

magenta_synthcoder_all, magenta_midi_ddsp_all = get_midi_ddsp_model(
                                    hp=hp,
                                    path=os.path.join(midi_ddsp_models_dir, 'magenta_all_full_model/50000'),
                                    example_input=all_train_sample_batch,
                                    use_original_model=True)

# Quick listening check: SynthCoder and full MIDI-DDSP output on a violin batch.
oo = magenta_midi_ddsp_all(vn_training_example)
play(oo['synth_audio'])
play(oo['midi_audio'])

### My vn

In [None]:
# Our own violin checkpoint, built with the local midi_ddsp package and mel
# input enabled (hp.use_mel=True stays set for later cells unless reset).
hp.multi_instrument=True
hp.instrument='vn'
#hp.batch_size = 1

hp.reverb_length = 48000
hp.use_mel=True

my_synthcoder_vn, my_midi_ddsp_vn = get_midi_ddsp_model(
                                    hp=hp,
                                    path=os.path.join(midi_ddsp_models_dir, 'my_vn_full_model/e235_s28669'),
                                    example_input=all_train_sample_batch,
                                    use_original_model=False)

oo = my_midi_ddsp_vn(vn_training_example)
play(oo['synth_audio'])
play(oo['midi_audio'])

### My all

In [None]:
# Our own multi-instrument checkpoint. hp.use_mel is not set here, so it keeps
# the value from the previous cell (True when cells run in order).
hp.multi_instrument=True
hp.instrument='all'
#hp.batch_size = 1

hp.reverb_length = 16000

my_synthcoder_all, my_midi_ddsp_all = get_midi_ddsp_model(
                                    hp=hp,
                                    path=os.path.join(midi_ddsp_models_dir, 'my_all_full_model/e24_s23975'),
                                    example_input=all_train_sample_batch,
                                    use_original_model=False)

oo = my_midi_ddsp_all(vn_training_example)
play(oo['synth_audio'])
play(oo['midi_audio'])

### Synthcoder vn low

In [None]:
# Early (epoch 10) single-instrument violin checkpoint.
# NOTE(review): unlike the other cells, this path is relative to the notebook's
# working directory instead of midi_ddsp_models_dir.
hp.multi_instrument=False
hp.instrument='vn'

hp.reverb_length = 16000

synthcoder_vn_low, midi_ddsp_vn_low = get_midi_ddsp_model(hp=hp, path='logs/logs_synthesis_generator/full_vn_train_and_export/e10_s1939', example_input=vn_eval_sample_batch, use_original_model=False)

oo = midi_ddsp_vn_low(vn_training_example)
play(oo['synth_audio'])
play(oo['midi_audio'])

### Synthcoder vn high

In [None]:
# NOTE(review): despite the "vn high" heading, the vn-high checkpoint load is
# commented out and this cell only aliases the already-loaded Magenta
# all-instrument model as *_all_high; the hp changes here load nothing.
hp.multi_instrument=False
hp.instrument='vn'

hp.reverb_length = 16000

#synthcoder_vn_high, midi_ddsp_vn_high = get_midi_ddsp_model(hp=hp, path='logs/logs_synthesis_generator/full_vn_train_and_export/e31_s6013', example_input=vn_eval_sample_batch, use_original_model=False)
synthcoder_all_high, midi_ddsp_all_high = magenta_synthcoder_all, magenta_midi_ddsp_all

oo = midi_ddsp_all_high(vn_training_example)
play(oo['synth_audio'])
play(oo['midi_audio'])

### SynthCoder ablation (without mel)

In [None]:
# Ablation: SynthCoder trained without mel input (hp.use_mel=False).
# NOTE(review): the checkpoint folder name mentions "vn" while hp.instrument is
# 'all' - confirm which setting the checkpoint was trained with.
hp.multi_instrument=True
hp.instrument='all'

hp.reverb_length = 16000
hp.use_mel = False

synthcoder_no_mel_all, midi_ddsp_no_mel_all = get_midi_ddsp_model(hp=hp, 
                                                                  path=os.path.join(midi_ddsp_models_dir, 'ablation_study_vn_synthcoder_witout_mel/e300_s249599'), 
                                                                  example_input=vn_eval_sample_batch, 
                                                                  use_original_model=False)

oo = midi_ddsp_no_mel_all(vn_training_example)
play(oo['synth_audio'])
play(oo['midi_audio'])

### SynthCoder ablation (without mel) weird

In [None]:
# Second no-mel SynthCoder checkpoint; differs from the previous cell in its
# reverb_length (48000 vs 16000) and checkpoint.
hp.multi_instrument=True
hp.instrument='all'

hp.reverb_length = 48000
hp.use_mel = False

synthcoder_no_mel_all_weird, midi_ddsp_no_mel_all_weird = get_midi_ddsp_model(hp=hp, 
                                                                  path=os.path.join(midi_ddsp_models_dir, 'weird_synthcoder_without_mel/e16_s9983'), 
                                                                  example_input=vn_eval_sample_batch, 
                                                                  use_original_model=False)

oo = midi_ddsp_no_mel_all_weird(vn_training_example)
play(oo['synth_audio'])
play(oo['midi_audio'])

In [None]:
# Listen to the Magenta model's full MIDI-DDSP output on a multi-instrument batch.
play(magenta_midi_ddsp_all(all_training_example)['midi_audio'])

# Eval

In [None]:
def get_audio(synthcoder_output, processor_group):
    """Render audio from synth-coder outputs with `processor_group`.

    Controls are computed first (quietly, verbose=False) and then turned into
    the output signal.
    """
    controls = processor_group.get_controls(synthcoder_output, verbose=False)
    return processor_group.get_signal(controls)

def downsample_tensor(tensor, factor):
    """Average-pool a batched tensor along axis 1 by an integer `factor`.

    Rank-2 tensors ``(batch, time)`` and rank-3 tensors ``(batch, time, n)``
    are reduced to ``time // factor`` steps by averaging each run of
    `factor` consecutive steps. Tensors of any other rank are returned
    unchanged.

    Raises:
        ValueError: if the axis-1 length is not divisible by `factor`.
    """
    if len(tensor.shape) not in (2, 3):
        return tensor

    batch_size, time_length = tensor.shape[0], tensor.shape[1]
    if time_length % factor != 0:
        raise ValueError(
            "Time axis length must be divisible by the downsample factor.")

    # Split axis 1 into (time // factor, factor) and average the new inner
    # axis; a trailing channel axis (rank 3) is carried through unchanged.
    pooled_shape = [batch_size, time_length // factor, factor] + list(tensor.shape[2:])
    return tf.reduce_mean(tf.reshape(tensor, pooled_shape), axis=2)

def downsample_synth_params(synth_params: dict, factor):
    """Downsample every synthesis-parameter tensor by `factor` along time.

    Unless `factor` is greater than 1, the very same dict object is returned.
    """
    if not factor > 1:
        return synth_params
    downsampled = {}
    for param_name, param_value in synth_params.items():
        downsampled[param_name] = downsample_tensor(param_value, factor)
    return downsampled

def downsample_inputs(inputs: dict, factor):
    """Downsample every input feature except raw 'audio' by `factor`.

    'audio' is passed through untouched. Unless `factor` is greater than 1,
    the original dict object is returned.
    """
    if not factor > 1:
        return inputs
    downsampled = {}
    for feature_name, feature_value in inputs.items():
        if feature_name == 'audio':
            downsampled[feature_name] = feature_value
        else:
            downsampled[feature_name] = downsample_tensor(feature_value, factor)
    return downsampled

In [None]:
import numpy as np
import plotly.graph_objects as go
import plotly.io as pio

def plot_tensor_3d_as_image(tensor, title):
    """Plot a 3-D tensor as a single heatmap via `plot_tensor_2d_as_image`.

    Assumes a ``(batch, time, n)`` layout. The batch and ``n`` axes are
    merged (grouped by batch item) into one axis of size ``batch * n``,
    giving a ``(batch * n, time)`` array.

    Args:
        tensor: 3-D tf.Tensor or array-like (anything np.asarray accepts).
        title: figure title.
    """
    tensor_np = np.asarray(tensor)
    batch_size, time_length, n = tensor_np.shape

    # Move the `n` axis next to the batch axis before merging them. A plain
    # reshape of (batch, time, n) to (batch * n, time) would interleave time
    # steps and features instead of stacking per-feature rows.
    tensor_2d = np.transpose(tensor_np, (0, 2, 1)).reshape(batch_size * n, time_length)

    plot_tensor_2d_as_image(tensor=tensor_2d, title=title)

def plot_tensor_2d_as_image(tensor, title, savepath=None, saveprefix=None, showfig=True):
    """Draw a 2-D tensor as a Plotly heatmap; optionally show and/or save it.

    The tensor is transposed before plotting, so its first axis runs along x
    (labelled "Samples") and its second axis along y.

    Args:
        tensor: 2-D tensor/array.
        title: figure title; also part of the saved file name.
        savepath: directory in which ``{saveprefix}_{title}.png`` is written
            (created if missing); None disables saving.
        saveprefix: file-name prefix for the saved image.
        showfig: display the figure with ``fig.show()``.
    """
    fig = go.Figure(go.Heatmap(z=tf.transpose(tensor), colorscale='Magma'))

    fig.update_layout(
        xaxis_title="Samples",
        yaxis_title="",
        title=title,
        width=32*37,   # 32 "inches" at 37 pixels per inch
        height=16*37,  # 16 "inches" at 37 pixels per inch
        autosize=False,
        margin=dict(l=0, r=0, t=50, b=0)
    )

    if showfig:
        fig.show()

    if savepath:
        # exist_ok avoids the race between an exists() check and makedirs().
        os.makedirs(savepath, exist_ok=True)
        pio.write_image(fig, os.path.join(savepath, f'{saveprefix}_{title}.png'))
        
def line_plot_tensor_2d(tensor, title, yaxis_title="", max_yaxis=None, savepath=None, saveprefix=None, showfig=True):
    """Plot a tensor that squeezes to 1-D as a line over frame index.

    Args:
        tensor: tensor/array whose non-singleton axis holds the values, e.g.
            an f0 curve shaped (frames, 1).
        title: figure title; also part of the saved file name.
        yaxis_title: y-axis label.
        max_yaxis: if given, the y axis is fixed to ``[0, max_yaxis]``.
        savepath: directory in which ``{saveprefix}_{title}.png`` is written
            (created if missing); None disables saving.
        saveprefix: file-name prefix for the saved image.
        showfig: display the figure with ``fig.show()``.
    """
    # atleast_1d keeps a single-value input plottable (squeeze alone would
    # yield a 0-d array that has no len()).
    values = np.atleast_1d(np.squeeze(np.asarray(tensor)))
    frames = np.arange(0, len(values))

    fig = go.Figure(go.Scatter(x=frames, y=values, mode='lines'))

    fig.update_layout(
        xaxis_title="Frames",
        yaxis_title=yaxis_title,
        title=title,
        width=32*37,   # 32 "inches" at 37 pixels per inch
        height=16*37,  # 16 "inches" at 37 pixels per inch
        autosize=False,
        margin=dict(l=0, r=0, t=50, b=0)
    )

    if max_yaxis is not None:
        fig.update_layout(yaxis=dict(range=[0, max_yaxis]))

    if showfig:
        fig.show()

    if savepath:
        # exist_ok avoids the race between an exists() check and makedirs().
        os.makedirs(savepath, exist_ok=True)
        pio.write_image(fig, os.path.join(savepath, f'{saveprefix}_{title}.png'))

In [None]:
import librosa
from copy import deepcopy
from pydub import AudioSegment


def read_wav(filename, dir='../../data_synths', target_sample_rate=16000):
    """Load an audio file and resample it to `target_sample_rate`.

    Args:
        filename: file name inside `dir`.
        dir: directory containing the file.
        target_sample_rate: sample rate (Hz) of the returned audio.

    Returns:
        float32 tf.Tensor with the resampled samples (mono, which is
        librosa.load's default).
    """
    audio_file = os.path.join(dir, filename)
    # sr=None keeps the file's native rate; the resample below does the conversion.
    audio_data, original_sample_rate = librosa.load(audio_file, sr=None)

    # orig_sr / target_sr are keyword-only since librosa 0.10; passing them by
    # name works with older versions as well.
    resampled_audio_data = librosa.resample(
        audio_data, orig_sr=original_sample_rate, target_sr=target_sample_rate)

    return tf.convert_to_tensor(resampled_audio_data, dtype=tf.float32)


def save_wav(audio_tensor, filename, dir, target_sample_rate=16000):
    """Write `audio_tensor` to ``dir/filename`` with midi_ddsp's WAV writer.

    The file is written at `target_sample_rate` Hz.
    """
    output_path = os.path.join(dir, filename)
    midi_ddsp.utils.audio_io.save_wav(
        wav=audio_tensor,
        path=output_path,
        sample_rate=target_sample_rate,
    )

In [None]:
from midi_ddsp.data_handling.instrument_name_utils import INST_ID_TO_NAME_DICT
from midi_ddsp.utils.audio_io import tf_log_mel

def play_results(inputs, outputs, experiment_name,
                 show_midi_ddsp=True,
                 plot_inputs=False,
                 plot_synthcoder=False,
                 plot_midi_ddsp=False,
                 plot_io_diff=False,
                 plot_expression_features=False,
                 fig_savepath=None,
                 showfig=True):
    """Play, and optionally plot, inputs and model outputs for each batch item.

    For every item this plays the ground-truth input audio and the SynthCoder
    output (`outputs['synth_audio']`). When `show_midi_ddsp` is set, it also
    plays the Expression Decoder output (`outputs['midi_audio']`). Figures
    are shown when `showfig` is true and saved under `fig_savepath` when it
    is given; file names are prefixed with the batch index.

    Args:
        inputs: batch dict with at least 'audio' and 'instrument_id'. It may
            also hold 'instrument_id_gt' (the source instrument, as set by
            perform_timbre_transfer). 'f0_hz' is needed when `plot_inputs` is
            set.
        outputs: model output dict with 'synth_audio' and, depending on the
            flags, 'synth_params', 'midi_audio', 'midi_synth_params' and
            'conditioning_dict'.
        experiment_name: label printed before every clip.
        show_midi_ddsp: also play/plot the Expression Decoder output.
        plot_inputs, plot_synthcoder, plot_midi_ddsp: plot the log-mel
            spectrogram (and synthesis parameters) of the respective audio.
        plot_io_diff: plot |output mel - input mel|. The spectrograms it needs
            are computed even when the corresponding plot flags are off.
        plot_expression_features: plot the attack/vibrato/volume curves.
        fig_savepath: directory for saved figures, or None to not save.
        showfig: display figures interactively.
    """
    for batch_idx in range(outputs['synth_audio'].shape[0]):

        pref = f'{batch_idx}'
        plot_kwargs = dict(savepath=fig_savepath, saveprefix=pref, showfig=showfig)

        instrument_name = INST_ID_TO_NAME_DICT[int(inputs["instrument_id"][batch_idx])]

        # After a timbre transfer 'instrument_id' is the target instrument and
        # 'instrument_id_gt' the source; without the latter they are the same.
        if "instrument_id_gt" in inputs:
            instrument_name_gt = INST_ID_TO_NAME_DICT[int(inputs["instrument_id_gt"][batch_idx])]
        else:
            instrument_name_gt = instrument_name

        input_audio = inputs['audio'][batch_idx]
        input_mel = None
        if plot_inputs or plot_io_diff:
            input_mel = _log_mel_for_plot(input_audio)
        if plot_inputs:
            plot_tensor_2d_as_image(input_mel, title='Input audio', **plot_kwargs)
            line_plot_tensor_2d(inputs['f0_hz'][batch_idx], title='f0', **plot_kwargs)

        print(f'{experiment_name}: instrument={instrument_name_gt}: ground truth: ')
        play(input_audio)

        synthcoder_audio = outputs['synth_audio'][batch_idx]
        synthcoder_mel = None
        if plot_synthcoder or plot_io_diff:
            synthcoder_mel = _log_mel_for_plot(synthcoder_audio)
        if plot_synthcoder:
            plot_tensor_2d_as_image(synthcoder_mel, title='SynthCoder output audio', **plot_kwargs)
            plot_tensor_2d_as_image(outputs['synth_params']['noise_magnitudes'][batch_idx], title='SynthCoder output noise magnitudes', **plot_kwargs)
            plot_tensor_2d_as_image(outputs['synth_params']['harmonic_distribution'][batch_idx], title='SynthCoder output harmonic distribution', **plot_kwargs)

        if plot_io_diff:
            spectrogram_difference = tf.abs(synthcoder_mel - input_mel)
            plot_tensor_2d_as_image(spectrogram_difference, title='SynthCoder IO diff', **plot_kwargs)

        print(f'{experiment_name}: instrument={instrument_name}: synth_audio: ')
        play(synthcoder_audio)

        if show_midi_ddsp:

            if plot_expression_features:
                conditioning = outputs['conditioning_dict']
                for feature_key, feature_title in (('attack', 'Attack'),
                                                   ('vibrato', 'Vibrato'),
                                                   ('volume', 'Volume')):
                    line_plot_tensor_2d(conditioning[feature_key][batch_idx], title=feature_title, max_yaxis=1, **plot_kwargs)

            midi_audio = outputs['midi_audio'][batch_idx]
            midi_mel = None
            if plot_midi_ddsp or plot_io_diff:
                midi_mel = _log_mel_for_plot(midi_audio)
            if plot_midi_ddsp:
                plot_tensor_2d_as_image(midi_mel, title='ExpressionDecoder output audio', **plot_kwargs)
                plot_tensor_2d_as_image(outputs['midi_synth_params']['noise_magnitudes'][batch_idx], title='ExpressionDecoder output noise magnitudes', **plot_kwargs)
                plot_tensor_2d_as_image(outputs['midi_synth_params']['harmonic_distribution'][batch_idx], title='ExpressionDecoder output harmonic distribution', **plot_kwargs)

            if plot_io_diff:
                spectrogram_difference = tf.abs(midi_mel - input_mel)
                plot_tensor_2d_as_image(spectrogram_difference, title='ExpressionDecoder IO diff', **plot_kwargs)

            print(f'{experiment_name}: instrument={instrument_name}: midi_audio: ')
            play(midi_audio)

        print()


def _log_mel_for_plot(audio):
    """Log-mel spectrogram with the fixed settings used for play_results' plots."""
    return tf_log_mel(audio=audio,
                      sample_rate=16000,
                      win_length=64,
                      hop_length=64,
                      n_fft=1024,
                      num_mels=64,
                      fmin=40,
                      pad_end=True)

In [None]:
def perform_timbre_transfer(midi_ddsp_model, inputs, target_instrument_abb='vn'):
    """Run `midi_ddsp_model` on `inputs`, optionally relabelled as another instrument.

    Works on a deep copy, so the caller's dict is left untouched. The copy
    stores the original ids under 'instrument_id_gt'. When
    `target_instrument_abb` is truthy, 'instrument_id' is replaced by that
    instrument's id repeated once per batch item.

    Returns:
        (outputs, inputs_internal): the model outputs and the dict actually
        fed to the model.
    """
    model_inputs = deepcopy(inputs)
    model_inputs['instrument_id_gt'] = model_inputs['instrument_id']

    if target_instrument_abb:
        abb_to_id = midi_ddsp.data_handling.instrument_name_utils.INST_ABB_TO_ID_DICT
        target_id = tf.constant(abb_to_id[target_instrument_abb], shape=[1])
        num_items = inputs['audio'].shape[0]
        model_inputs['instrument_id'] = tf.repeat(target_id, num_items)

    return midi_ddsp_model(model_inputs), model_inputs

In [None]:
def eval_io(midi_ddsp_model,
            inputs,
            outputs,
            experiment_name,
            target_instrument_abb='vn',
            plot=False,
            fig_subfolder=None,
            showfig=False,
            save_io_dir=None,
            save_io_midi=True,
            save_io_synthcoder=False):
    """Listen to (and optionally plot) a model's results and save them as WAVs.

    Args:
        midi_ddsp_model: not used here; kept so the signature mirrors
            eval_timbre_transfer.
        inputs: batch dict fed to the model. 'instrument_id_gt' (the source
            instrument) is used for naming the saved input files when
            present, otherwise 'instrument_id'.
        outputs: model outputs ('synth_audio', 'midi_audio', ...).
        experiment_name: label for printing and the output subdirectory name.
        target_instrument_abb: instrument abbreviation used in output file names.
        plot: enable every plot in play_results.
        fig_subfolder: subfolder of ``save_io_dir/experiment_name`` for saved
            figures. Figures are saved only when this and `save_io_dir` are
            both set.
        showfig: display figures interactively.
        save_io_dir: root directory for WAV files (and figures); None saves nothing.
        save_io_midi: save the Expression Decoder output ('midi_audio').
        save_io_synthcoder: save the SynthCoder output ('synth_audio').
    """
    save_io_full_dir = os.path.join(save_io_dir, experiment_name) if save_io_dir else None

    if fig_subfolder and save_io_full_dir:
        fig_savepath = os.path.join(save_io_full_dir, fig_subfolder)
    else:
        fig_savepath = None

    play_results(inputs=inputs,
                 outputs=outputs,
                 experiment_name=experiment_name,
                 show_midi_ddsp=True,
                 plot_inputs=plot,
                 plot_synthcoder=plot,
                 plot_midi_ddsp=plot,
                 plot_io_diff=plot,
                 plot_expression_features=plot,
                 fig_savepath=fig_savepath,
                 showfig=showfig)

    if not save_io_full_dir:
        return

    os.makedirs(save_io_full_dir, exist_ok=True)

    # Source instrument of each item; inputs that did not go through
    # perform_timbre_transfer have no 'instrument_id_gt'.
    source_ids = inputs["instrument_id_gt"] if "instrument_id_gt" in inputs else inputs["instrument_id"]
    id_to_abb = midi_ddsp.data_handling.instrument_name_utils.INST_ID_TO_ABB_DICT

    for batch_idx in range(inputs['audio'].shape[0]):
        # Files are numbered from 1.
        file_number = batch_idx + 1
        original_instrument_abb = id_to_abb[int(source_ids[batch_idx])]
        save_wav(audio_tensor=inputs['audio'][batch_idx], filename=f'{file_number}_{original_instrument_abb}_in.wav', dir=save_io_full_dir)

        if save_io_synthcoder:
            save_wav(audio_tensor=outputs['synth_audio'][batch_idx], filename=f'{file_number}_{target_instrument_abb}_sc_out.wav', dir=save_io_full_dir)

        if save_io_midi:
            save_wav(audio_tensor=outputs['midi_audio'][batch_idx], filename=f'{file_number}_{target_instrument_abb}_out.wav', dir=save_io_full_dir)

In [None]:
def eval_timbre_transfer(midi_ddsp_model,
                         inputs,
                         experiment_name,
                         target_instrument_abb='vn',
                         plot=False,
                         fig_subfolder=None,
                         showfig=False,
                         save_io_dir=None,
                         save_io_midi=True,
                         save_io_synthcoder=False):
    """Transfer `inputs` to `target_instrument_abb` and evaluate the result.

    Runs perform_timbre_transfer, then hands the transferred inputs and the
    model outputs to eval_io for listening, plotting and saving. All other
    arguments are forwarded unchanged to eval_io. Returns None, so the
    notebook cell does not echo large tensors.
    """
    outputs, transferred_inputs = perform_timbre_transfer(
        midi_ddsp_model=midi_ddsp_model,
        inputs=inputs,
        target_instrument_abb=target_instrument_abb)

    # Pass the transferred inputs: they carry 'instrument_id_gt' (the source
    # instrument), which eval_io uses to name the saved input files.
    eval_io(midi_ddsp_model=midi_ddsp_model,
            inputs=transferred_inputs,
            outputs=outputs,
            experiment_name=experiment_name,
            target_instrument_abb=target_instrument_abb,
            plot=plot,
            fig_subfolder=fig_subfolder,
            showfig=showfig,
            save_io_dir=save_io_dir,
            save_io_midi=save_io_midi,
            save_io_synthcoder=save_io_synthcoder)

In [None]:
# Timbre-transfer the first batch of hand-picked examples to violin with the
# Magenta model and save inputs + MIDI-DDSP outputs under save_io_dir.
# NOTE(review): experiment_name says 'my_midi_ddsp_vn' but the model is
# magenta_midi_ddsp_all, so these results land in that folder - confirm intended.
eval_timbre_transfer(midi_ddsp_model=magenta_midi_ddsp_all, 
                     inputs=next(iter(selected_eval_dataset)), 
                     experiment_name='my_midi_ddsp_vn',
                     target_instrument_abb='vn',
                     plot=False,
                     fig_subfolder=None,
                     showfig=False,
                     save_io_dir='E:/Code/TimbreTransfer_ExperimentExamples/MIDI_DDSP/auto/')

In [None]:
# Our violin model: save only the SynthCoder outputs, plus spectrogram figures
# under '<experiment>/specs' (figures are saved, not shown).
eval_timbre_transfer(midi_ddsp_model=my_midi_ddsp_vn, 
                     inputs=next(iter(selected_eval_dataset)), 
                     experiment_name='my_midi_ddsp_vn_sc_to_vn',
                     target_instrument_abb='vn',
                     plot=True,
                     fig_subfolder='specs',
                     showfig=False,
                     save_io_dir='E:/Code/TimbreTransfer_ExperimentExamples/MIDI_DDSP/auto/',
                     save_io_midi=False,
                     save_io_synthcoder=True)

In [None]:
# No-mel ablation model: same evaluation as above (SynthCoder outputs + figures).
eval_timbre_transfer(midi_ddsp_model=midi_ddsp_no_mel_all, 
                     inputs=next(iter(selected_eval_dataset)), 
                     experiment_name='midi_ddsp_no_mel_all',
                     target_instrument_abb='vn',
                     plot=True,
                     fig_subfolder='specs',
                     showfig=False,
                     save_io_dir='E:/Code/TimbreTransfer_ExperimentExamples/MIDI_DDSP/auto/',
                     save_io_midi=False,
                     save_io_synthcoder=True)

In [None]:
from midi_ddsp.utils.inference_utils import get_process_group


def test_model_resolution(midi_ddsp_model,
                          input_dict,
                          downsample_factor,
                          should_plot_f0,
                          experiment_name,
                          target_instrument_abb='vn',
                          plot=False,
                          fig_subfolder=None,
                          showfig=False,
                          save_io_dir=None,
                          save_io_midi=True,
                          save_io_synthcoder=False,
                          downsampling_target='inputs',
                          run_synth_coder_only=True):
    """Evaluate `midi_ddsp_model` at a frame rate reduced by `downsample_factor`.

    Either the model inputs are downsampled before the forward pass
    (``downsampling_target='inputs'``) or the predicted synth parameters are
    downsampled afterwards (``downsampling_target='outputs'``). The resulting
    parameters are rendered with a processor group whose frame size is scaled
    by the same factor, passed through the model's reverb module, and the
    inputs/outputs are handed to ``eval_io``.

    Args:
        midi_ddsp_model: Model exposing ``run_synth_coder_only`` and
            ``reverb_module``. Its ``run_synth_coder_only`` flag is switched
            for the forward pass and always restored afterwards.
        input_dict: Model input dict; must contain ``'instrument_id'`` (and
            ``'f0_hz'`` when ``should_plot_f0`` is set).
        downsample_factor: Integer factor by which the frame count is reduced.
        should_plot_f0: If True, plot ``input_dict['f0_hz']``.
        experiment_name: Forwarded to ``eval_io``; also used as the f0 plot title.
        target_instrument_abb, plot, fig_subfolder, showfig, save_io_dir,
        save_io_midi, save_io_synthcoder: Forwarded unchanged to ``eval_io``.
        downsampling_target: ``'inputs'`` or ``'outputs'``.
        run_synth_coder_only: Mode used for the forward pass. It also selects which
            output is downsampled: ``'synth_params'`` (True) or
            ``'midi_synth_params'`` (False).

    Raises:
        ValueError: If ``downsampling_target`` is not ``'inputs'`` or ``'outputs'``.
    """
    # Validate before touching the model so a bad argument leaves it unchanged.
    if downsampling_target not in ('inputs', 'outputs'):
        raise ValueError(f"no such downsampling_target: {downsampling_target}")

    params_name = 'synth_params' if run_synth_coder_only else 'midi_synth_params'

    original_run_synth_coder_only = midi_ddsp_model.run_synth_coder_only
    midi_ddsp_model.run_synth_coder_only = run_synth_coder_only
    try:
        if downsampling_target == 'inputs':
            model_inputs = downsample_inputs(input_dict, factor=downsample_factor)
            shapes = {k: v.shape[1] if len(v.shape) >= 2 else 0 for k, v in model_inputs.items()}
            print(f"Number of frames in input params: {shapes}")

            outputs = midi_ddsp_model(model_inputs)
            output_synth_params = outputs[params_name]
        else:
            # Full-rate forward pass; only the predicted parameters are reduced.
            # NOTE(review): eval_io then receives full-rate inputs next to
            # downsampled params — confirm eval_io handles the mismatch.
            model_inputs = input_dict
            outputs = midi_ddsp_model(input_dict)
            output_synth_params = downsample_synth_params(outputs[params_name], factor=downsample_factor)
    finally:
        # Restore the caller's mode even if the forward pass raised, so later
        # notebook cells see the model in its original state.
        midi_ddsp_model.run_synth_coder_only = original_run_synth_coder_only

    if should_plot_f0:
        line_plot_tensor_2d(input_dict['f0_hz'], title=experiment_name)

    shapes = {k: v.shape[1] if len(v.shape) >= 2 else 0 for k, v in output_synth_params.items()}
    print(f"Number of frames in synth params: {shapes}")

    # NOTE(review): assumes 1000 frames of 64 samples at 16 kHz at full rate —
    # confirm this matches the dataset before using other clip lengths.
    processor_group = get_process_group(n_frames=1000 // downsample_factor,
                                        frame_size=64 * downsample_factor,
                                        sample_rate=16000,
                                        use_angular_cumsum=False)

    output_audio = get_audio(synthcoder_output=output_synth_params, processor_group=processor_group)
    output_audio = midi_ddsp_model.reverb_module(output_audio, reverb_number=input_dict['instrument_id'], training=False)

    # Ensure eval_io sees the (possibly downsampled) parameters.
    outputs[params_name] = output_synth_params

    print(f'output_audio={output_audio}')

    # Shallow copy: in the 'outputs' branch model_inputs is the caller's dict,
    # which must not be mutated.
    eval_inputs = dict(model_inputs)
    eval_inputs['instrument_id_gt'] = eval_inputs['instrument_id']

    eval_io(midi_ddsp_model=midi_ddsp_model,
            inputs=eval_inputs,
            outputs=outputs,
            experiment_name=experiment_name,
            target_instrument_abb=target_instrument_abb,
            plot=plot,
            fig_subfolder=fig_subfolder,
            showfig=showfig,
            save_io_dir=save_io_dir,
            save_io_midi=save_io_midi,
            save_io_synthcoder=save_io_synthcoder)

In [None]:
from midi_ddsp.utils.inference_utils import get_process_group


def test_synthcoder_resolution(midi_ddsp_model, input_dict, downsample_factor, should_plot_f0, title, downsampling_target='inputs',
                               run_synth_coder_only=True):
    """Render a model's synth parameters at a reduced frame rate and play the results.

    Either the model inputs are downsampled before the forward pass
    (``downsampling_target='inputs'``) or the predicted parameters are
    downsampled afterwards (``'outputs'``). The parameters are rendered with a
    processor group scaled by ``downsample_factor`` and passed through the
    model's reverb. The input audio, this output, and the model's own
    ``'synth_audio'`` for a downsampled input are then played.

    Args:
        midi_ddsp_model: Model exposing ``run_synth_coder_only`` and
            ``reverb_module``. Its flag is switched for the first forward pass
            and always restored.
        input_dict: Model input dict with at least ``'audio'``,
            ``'instrument_id'`` (and ``'f0_hz'`` when plotting).
        downsample_factor: Integer factor by which the frame count is reduced.
        should_plot_f0: If True, plot ``input_dict['f0_hz']`` titled ``title``.
        title: Plot title.
        downsampling_target: ``'inputs'`` or ``'outputs'``.
        run_synth_coder_only: Mode for the first forward pass. It also selects
            ``'synth_params'`` (True) or ``'midi_synth_params'`` (False).

    Returns:
        The reverberated audio rendered from the downsampled parameters.

    Raises:
        ValueError: If ``downsampling_target`` is not ``'inputs'`` or ``'outputs'``.
    """
    # Validate before touching the model so a bad argument leaves it unchanged.
    if downsampling_target not in ('inputs', 'outputs'):
        raise ValueError(f"no such downsampling_target: {downsampling_target}")

    params_name = 'synth_params' if run_synth_coder_only else 'midi_synth_params'
    original_run_synth_coder_only = midi_ddsp_model.run_synth_coder_only
    midi_ddsp_model.run_synth_coder_only = run_synth_coder_only
    try:
        if downsampling_target == 'inputs':
            downsampled_inputs = downsample_inputs(input_dict, factor=downsample_factor)
            shapes = {k: v.shape[1] if len(v.shape) >= 2 else 0 for k, v in downsampled_inputs.items()}
            print(f"Number of frames in input params: {shapes}")

            output_synth_params = midi_ddsp_model(downsampled_inputs)[params_name]
        else:
            output_synth_params = downsample_synth_params(midi_ddsp_model(input_dict)[params_name],
                                                          factor=downsample_factor)
    finally:
        # Restore the caller's mode even if the forward pass raised.
        midi_ddsp_model.run_synth_coder_only = original_run_synth_coder_only

    if should_plot_f0:
        line_plot_tensor_2d(input_dict['f0_hz'], title=title)

    shapes = {k: v.shape[1] if len(v.shape) >= 2 else 0 for k, v in output_synth_params.items()}
    print(f"Number of frames in synth params: {shapes}")

    # NOTE(review): assumes 1000 frames of 64 samples at 16 kHz at full rate — confirm.
    processor_group = get_process_group(n_frames=1000 // downsample_factor,
                                        frame_size=64 * downsample_factor,
                                        sample_rate=16000,
                                        use_angular_cumsum=False)

    output_audio = get_audio(synthcoder_output=output_synth_params, processor_group=processor_group)
    output_audio = midi_ddsp_model.reverb_module(output_audio, reverb_number=input_dict['instrument_id'], training=False)

    print('input audio:')
    play(input_dict['audio'])

    print('output audio:')
    play(output_audio)

    # NOTE(review): this applies downsample_synth_params (not downsample_inputs)
    # to the raw input dict, and runs with the model's ORIGINAL
    # run_synth_coder_only setting — confirm this is the intended comparison.
    output_audio_internal = midi_ddsp_model(downsample_synth_params(input_dict, factor=downsample_factor))['synth_audio']

    print('output audio internal:')
    play(output_audio_internal)

    return output_audio


def test_synthcoder(midi_ddsp_model, input_dict, should_plot_f0, title):
    """Run the model once on `input_dict` and play the input vs. synthesized audio.

    Args:
        midi_ddsp_model: Callable model whose output dict contains ``'synth_audio'``.
        input_dict: Model inputs; ``'audio'`` is played and ``'f0_hz'`` is
            plotted when requested.
        should_plot_f0: If True, plot ``input_dict['f0_hz']`` titled ``title``.
        title: Plot title.

    Returns:
        The model's ``'synth_audio'`` output.
    """
    synth_audio = midi_ddsp_model(input_dict)['synth_audio']

    if should_plot_f0:
        line_plot_tensor_2d(input_dict['f0_hz'], title=title)

    print('input audio:')
    play(input_dict['audio'])
    print('output audio:')
    play(synth_audio)

    return synth_audio

In [None]:
# deepcopy is needed here; the notebook otherwise only imports it in a later cell.
from copy import deepcopy

from midi_ddsp.data_handling.instrument_name_utils import INST_ID_TO_NAME_DICT, INST_ABB_TO_ID_DICT


# Instrument id tensor for 'vn', assigned to every probe example below.
vn_id_constant = tf.constant(INST_ABB_TO_ID_DICT['vn'], shape=[1])

# Shape the replacement wavs are reshaped to.
# NOTE(review): confirm read_wav returns exactly 64000 samples for these files.
_PROBE_AUDIO_SHAPE = [1, 64000]


def _as_violin(example, **overrides):
    """Return a deep copy of `example` with `overrides` applied and instrument_id set to 'vn'."""
    variant = deepcopy(example)
    variant.update(overrides)
    variant['instrument_id'] = vn_id_constant
    return variant


vn_training_example_base = deepcopy(vn_training_example)

cl_training_example_base = deepcopy(cl_training_example)

# Clarinet example with f0 doubled (one octave up).
cl_training_example_oct_up = _as_violin(cl_training_example_base,
                                        f0_hz=2 * cl_training_example_base['f0_hz'])

# Octave-up clarinet f0 paired with uniform random audio (tf.random.uniform defaults to [0, 1)).
cl_training_example_oct_up_pw2 = _as_violin(
    cl_training_example_oct_up,
    audio=tf.random.uniform(tf.shape(cl_training_example_oct_up['audio'])))

# Clarinet f0 (original / octave up) paired with the URMP violin example's audio.
cl_training_example_replaced_urmp_violin = _as_violin(
    cl_training_example_base, audio=vn_training_example['audio'])
cl_training_example_oct_up_replaced_urmp_violin = _as_violin(
    cl_training_example_oct_up, audio=vn_training_example['audio'])

pluck_arpeggio = read_wav('pluck_arpeggio.wav')
cl_training_example_oct_up_replaced_audio_pluck_arpeggio = _as_violin(
    cl_training_example_oct_up, audio=tf.reshape(pluck_arpeggio, _PROBE_AUDIO_SHAPE))

pluck_arpeggio_bass = read_wav('pluck_arpeggio_bass.wav')
# NOTE: despite `oct_up` in its name, this probe copies cl_training_example and so keeps
# the original clarinet f0, matching its 'clarinet with audio of bass' label below.
cl_training_example_oct_up_replaced_audio_pluck_arpeggio_bass = _as_violin(
    cl_training_example, audio=tf.reshape(pluck_arpeggio_bass, _PROBE_AUDIO_SHAPE))

vn_orchestra_downward = read_wav('vn_orchestra_downward.wav')
cl_training_example_oct_up_replaced_audio_vn_orchestra_downward = _as_violin(
    cl_training_example_oct_up, audio=tf.reshape(vn_orchestra_downward, _PROBE_AUDIO_SHAPE))


instr_name = INST_ID_TO_NAME_DICT[int(cl_training_example_base['instrument_id'])]

# Plot title -> probe example.
samples_to_run = {
    'Inference on urmp violin': vn_training_example_base,
    'Inference on urmp clarinet': cl_training_example_base,
    'Inference on urmp clarinet octave up': cl_training_example_oct_up,
    'Inference on urmp clarinet octave up with audio of noise': cl_training_example_oct_up_pw2,
    'Inference on urmp clarinet with audio of urmp violin': cl_training_example_replaced_urmp_violin,
    'Inference on urmp clarinet octave up with audio of urmp violin': cl_training_example_oct_up_replaced_urmp_violin,
    'Inference on urmp clarinet octave up with audio of pluck arp synth': cl_training_example_oct_up_replaced_audio_pluck_arpeggio,
    'Inference on urmp clarinet with audio of bass pluck arp synth': cl_training_example_oct_up_replaced_audio_pluck_arpeggio_bass,
    'Inference on urmp clarinet octave up with audio of kontakt vn orchestra': cl_training_example_oct_up_replaced_audio_vn_orchestra_downward,
}

In [None]:
# my_midi_ddsp_vn with run_synth_coder_only=False and inputs downsampled 8x;
# save_io_midi=True, save_io_synthcoder=False.
midi_resolution_batch = next(iter(selected_eval_dataset))
test_model_resolution(
    midi_ddsp_model=my_midi_ddsp_vn,
    input_dict=midi_resolution_batch,
    downsample_factor=8,
    downsampling_target='inputs',
    run_synth_coder_only=False,
    should_plot_f0=True,
    experiment_name='my_midi_ddsp_vn_downsample_by_8',
    target_instrument_abb='vn',
    plot=True,
    fig_subfolder='specs',
    showfig=False,
    save_io_dir='E:/Code/TimbreTransfer_ExperimentExamples/MIDI_DDSP/auto/',
    save_io_midi=True,
    save_io_synthcoder=False,
)

In [None]:
# my_midi_ddsp_vn with inputs downsampled 8x; save_io_synthcoder=True, save_io_midi=False.
# NOTE(review): the experiment name says 'sc' but run_synth_coder_only=False, so
# test_model_resolution downsamples 'midi_synth_params' — confirm this is intended.
sc_resolution_batch = next(iter(selected_eval_dataset))
test_model_resolution(
    midi_ddsp_model=my_midi_ddsp_vn,
    input_dict=sc_resolution_batch,
    downsample_factor=8,
    downsampling_target='inputs',
    run_synth_coder_only=False,
    should_plot_f0=True,
    experiment_name='my_midi_ddsp_vn_sc_downsample_by_8',
    target_instrument_abb='vn',
    plot=True,
    fig_subfolder='specs',
    showfig=False,
    save_io_dir='E:/Code/TimbreTransfer_ExperimentExamples/MIDI_DDSP/auto/',
    save_io_midi=False,
    save_io_synthcoder=True,
)

In [None]:
# Play every probe example through my_midi_ddsp_vn at 1/8 of the frame rate.
for sample_title, sample_inputs in samples_to_run.items():
    test_synthcoder_resolution(
        midi_ddsp_model=my_midi_ddsp_vn,
        input_dict=sample_inputs,
        downsample_factor=8,
        should_plot_f0=True,
        title=sample_title,
        downsampling_target='inputs',
        run_synth_coder_only=False,
    )

In [None]:
# test_synthcoder_resolution takes `midi_ddsp_model` (it reads run_synth_coder_only and
# reverb_module from it) and requires should_plot_f0 and title; the old
# `synthcoder_model=` keyword raised TypeError.
# NOTE(review): confirm `synthcoder` provides those attributes, or pass a MIDI-DDSP model.
test_synthcoder_resolution(midi_ddsp_model=synthcoder,
                           input_dict=cl_training_example,
                           downsample_factor=2,
                           should_plot_f0=False,
                           title='cl_training_example, downsample_factor=2')

In [None]:
# test_synthcoder_resolution takes `midi_ddsp_model` (it reads run_synth_coder_only and
# reverb_module from it) and requires should_plot_f0 and title; the old
# `synthcoder_model=` keyword raised TypeError.
# NOTE(review): confirm `synthcoder` provides those attributes, or pass a MIDI-DDSP model.
test_synthcoder_resolution(midi_ddsp_model=synthcoder,
                           input_dict=cl_training_example,
                           downsample_factor=4,
                           should_plot_f0=False,
                           title='cl_training_example, downsample_factor=4')

In [None]:
# test_synthcoder_resolution takes `midi_ddsp_model` (it reads run_synth_coder_only and
# reverb_module from it) and requires should_plot_f0 and title; the old
# `synthcoder_model=` keyword raised TypeError.
# NOTE(review): confirm `synthcoder` provides those attributes, or pass a MIDI-DDSP model.
test_synthcoder_resolution(midi_ddsp_model=synthcoder,
                           input_dict=cl_training_example,
                           downsample_factor=8,
                           should_plot_f0=False,
                           title='cl_training_example, downsample_factor=8')

In [None]:
# test_synthcoder_resolution takes `midi_ddsp_model` (it reads run_synth_coder_only and
# reverb_module from it) and requires should_plot_f0 and title; the old
# `synthcoder_model=` keyword raised TypeError.
# NOTE(review): confirm `synthcoder` provides those attributes, or pass a MIDI-DDSP model.
test_synthcoder_resolution(midi_ddsp_model=synthcoder,
                           input_dict=cl_training_example,
                           downsample_factor=20,
                           should_plot_f0=False,
                           title='cl_training_example, downsample_factor=20')

In [None]:
from copy import deepcopy

# Same clarinet example, but with instrument_id forced to 20.
# NOTE(review): 20 is presumably the violin id (per the variable name);
# INST_ABB_TO_ID_DICT['vn'] would make that explicit — confirm the two agree.
cl_training_example_with_vn_instrument = deepcopy(cl_training_example)
cl_training_example_with_vn_instrument['instrument_id'] = tf.constant([20])

print(cl_training_example['instrument_id'])
print(cl_training_example['instrument_id'].shape)
print(cl_training_example_with_vn_instrument['instrument_id'])
print(cl_training_example_with_vn_instrument['instrument_id'].shape)

# Render both versions so the effect of the instrument id can be compared;
# otherwise the relabelled copy built above is never used.
print('synth_audio with original instrument_id:')
play(midi_ddsp_model(cl_training_example)['synth_audio'])
print('synth_audio with instrument_id=20:')
play(midi_ddsp_model(cl_training_example_with_vn_instrument)['synth_audio'])

In [None]:
# Display the clarinet example's f0 contour as the cell output.
cl_f0_hz = cl_training_example['f0_hz']
cl_f0_hz

In [None]:
import pickle

# Build each path component-wise: a literal '\\' inside a path string is only a
# separator on Windows, but the notebook also supports non-Windows runs.
_loss_history_dir = os.path.join(midi_ddsp_models_dir, 'my_vn_full_model')

# NOTE: pickle.load can execute arbitrary code; only load histories you produced yourself.
with open(os.path.join(_loss_history_dir, 'losses_history_e235_s28669'), 'rb') as history_file:
    losses_history = pickle.load(history_file)

with open(os.path.join(_loss_history_dir, 'eval_losses_history_e235_s28669'), 'rb') as history_file:
    eval_losses_history = pickle.load(history_file)



In [None]:
# Dump the raw training-loss history loaded from disk.
history_to_show = losses_history
print(history_to_show)

In [None]:
import plotly.graph_objects as go

# Placeholder curves used to prototype the loss plot (not the histories loaded above).
train_loss = [0.5, 0.4, 0.3, 0.2, 0.15, 0.1]
eval_loss = [0.6, 0.5, 0.4, 0.3, 0.25, 0.2]


def _epoch_line(values, label):
    """Build a line trace of `values` plotted against 1-based epoch numbers."""
    epochs = [epoch for epoch in range(1, len(values) + 1)]
    return go.Scatter(x=epochs, y=values, mode='lines', name=label)


train_trace = _epoch_line(train_loss, 'Train Loss')
eval_trace = _epoch_line(eval_loss, 'Eval Loss')

layout = go.Layout(title='Training and Evaluation Loss',
                   xaxis={'title': 'Epoch'},
                   yaxis={'title': 'Loss'})

fig = go.Figure(data=[train_trace, eval_trace], layout=layout)
fig.show()