In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import tensorflow as tf
import numpy as np
import pickle
import librosa

import matplotlib.pyplot as plt
import IPython.display as ipd

from vyae import create_models

In [None]:
# Load the stored record of one training run, rebuild its models with the
# run's weights, and plot the recorded loss history.
run_id = "20220426_17414_sb_full"

# NOTE(security): pickle.load can execute arbitrary code while unpickling.
# Only open info.dict files written by training runs you trust.
# The file is kept open only for the load itself. Errors are deliberately not
# swallowed: every later cell needs `args` and the models, so a failed load
# must stop here instead of leaving stale objects from a previous run in place.
with open('./results/%s/info.dict' % run_id, "rb") as info_file:
    data = pickle.load(info_file)

hist = data["adverse_history"]
args = data["args"]

encoder, decoder, autoencoder, discriminator, audio_pre, audio_post = create_models(args, weights=run_id)

# Loss curves to draw. Other keys that were toggled here before and are
# presumably also stored in `hist` (verify): "L", "L_kl", "L_ec", "L_re", "L_ii".
loss_keys = ["L_ms"]

plt.figure(figsize=(8,5))
for loss_key in loss_keys:
    plt.plot(hist[loss_key], label=loss_key)
plt.legend()
# plt.gca().set_yscale('log')  # optional: logarithmic loss axis
plt.xlabel("Epoch #")
plt.ylabel("Loss [arb.]")
plt.show()
    


In [None]:
# Process a file frame-wise with the model (overlap-add resynthesis).
#
# The input is cut into frames of args.L samples that overlap by O samples.
# Each frame is passed through audio_pre -> encoder -> decoder -> audio_post,
# with the decoder receiving a fixed one-hot target class instead of the first
# encoder output. The processed frames are cross-faded and summed into y.

x, fs2 = librosa.load("data/test_input.wav", sr=args.fs)
assert args.fs == fs2, "Sampling rate does not match."
T = len(x)

O = 128          # overlap between consecutive frames, in samples
H = args.L - O   # hop between frame starts, in samples
if H <= 0:
    raise ValueError("Frame length args.L (%d) must exceed the overlap O (%d)." % (args.L, O))

y = np.zeros_like(x)
# Linear ramps; in an overlap region fade_out of one frame plus fade_in of the
# next frame add up to 1.
fade_in = np.linspace(0,1,O) # np.sin(0.5*np.pi*np.linspace(0, 1, O))
fade_out = np.linspace(1,0,O) # np.cos(0.5*np.pi*np.linspace(0, 1, O))

# Length of the one-hot conditioning vector; presumably has to match the number
# of classes the loaded model was trained with -- TODO confirm against args.
NUM_CLASSES = 16
target_class = 12
# tf.one_hot silently returns an all-zero vector for an out-of-range index.
if not 0 <= target_class < NUM_CLASSES:
    raise ValueError("target_class must be in [0, %d)." % NUM_CLASSES)

# The conditioning vector is the same for every frame, so build it once.
# Shape (1, NUM_CLASSES), identical to squeezing one_hot([[c]], n) on axis 1.
target_onehot = tf.one_hot([target_class], NUM_CLASSES)

# Only frames that fit completely into the signal are processed; samples after
# the last full frame therefore stay zero in y.
for s in range(0, T - args.L + 1, H):
    S, P = audio_pre(x[s:s+args.L].reshape(1,1,args.L))
    # The encoder returns four values when args.use_vae is set and two
    # otherwise; only the second one (i) is forwarded to the decoder.
    if args.use_vae:
        e,i,_,_ = encoder(S)
    else:
        e,i = encoder(S)

    R = decoder([target_onehot, i])
    out = audio_post([R, P])[0,0].numpy()
    out[:O] *= fade_in
    out[-O:] *= fade_out
    y[s:s+args.L] += out

print("Original:")
ipd.display(ipd.Audio(x, rate=args.fs))
print("Resynthesized:")
ipd.display(ipd.Audio(y, rate=args.fs))

In [None]:
# Plot dB spectrograms of the same excerpt of the original signal (x) and the
# resynthesized signal (y) side by side on a shared, logarithmic frequency axis.
start = int(7*args.fs)
end = start + int(6.05*args.fs)

N = 2048     # STFT window / FFT length in samples
H = 512      # STFT hop in samples
mini = -60   # lower display limit in dB
maxi = 10    # upper display limit in dB


def _db_spectrogram(signal, n_fft, hop_length, db_min, db_max):
    """Return the STFT magnitude of `signal` in dB, limited to [db_min, db_max].

    The magnitude is floored at the amplitude corresponding to `db_min` before
    taking the logarithm. Bins that are exactly zero (e.g. silent stretches)
    therefore map to `db_min` without triggering divide-by-zero warnings; after
    clipping the result equals clip(20*log10(|STFT|), db_min, db_max).
    """
    magnitude = np.abs(librosa.stft(signal, n_fft=n_fft, hop_length=hop_length))
    amplitude_floor = 10.0 ** (db_min / 20.0)
    level_db = 20 * np.log10(np.maximum(magnitude, amplitude_floor))
    return np.clip(level_db, db_min, db_max)


x_e = x[start:end]
y_e = y[start:end]
if len(x_e) == 0 or len(y_e) == 0:
    raise ValueError("Excerpt [%d:%d] lies outside the signal (length %d)." % (start, end, len(x)))

S_x = _db_spectrogram(x_e, N, H, mini, maxi)
S_y = _db_spectrogram(y_e, N, H, mini, maxi)

# Axis vectors: one time stamp per STFT frame, one frequency per rFFT bin.
t = np.arange(S_x.shape[1]) * H / args.fs
f = np.fft.rfftfreq(N, 1/args.fs)

fig, axs = plt.subplots(1, 2, figsize=(12, 4.6), dpi=100, sharex=True, sharey=True, tight_layout=True)

# Pin both panels to the same colour range; otherwise each panel is normalised
# to its own data range and the colours cannot be compared between them.
axs[0].set_title("Original")
axs[0].pcolormesh(t, f, S_x, vmin=mini, vmax=maxi)
axs[1].set_title("Resynthesized")
axs[1].pcolormesh(t, f, S_y, vmin=mini, vmax=maxi)

axs[0].set_ylabel("Frequency [Hz]")

axs[0].set_xlabel("Time [s]")
axs[1].set_xlabel("Time [s]")

# The y axis is shared, so scale, limits and ticks set on the first panel
# apply to both.
axs[0].set_yscale("log")
axs[0].set_ylim(20, 8000)
axs[0].set_yticks([50, 100, 250, 500, 1000, 2500, 5000], [50, 100, 250, 500, 1000, 2500, 5000])

plt.show()