In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import librosa
import librosa.display
import matplotlib.pyplot as plt

import numpy as np
import scipy.signal

import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_probability as tfp
from tensorflow.keras import backend as K

K.set_floatx('float64')

from tensorflow.keras import metrics
from tensorflow.keras import layers
from tensorflow.keras import optimizers
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Lambda, Flatten, Reshape
from tensorflow.keras.layers import Layer, Add, Multiply, Conv1D, Conv1DTranspose
from tensorflow_addons.layers import WeightNormalization
from tensorflow.keras.initializers import RandomUniform, Constant

from tf_extensions.tf_custom.layers import MixingBlock, Snake
from tf_extensions.tf_custom.models import GaussianBetaVAE

#### Load the spectral audio data 

In [2]:
import joblib

_PATH_TO_AUDIO_DATA = "./dance_wav/audio_spectral_data_stft.pkl"
spectral_audio_dataset = joblib.load(_PATH_TO_AUDIO_DATA)

meta_data = spectral_audio_dataset["MetaInfo"]
SR        = meta_data["SampleRate"]
DUR       = meta_data["ClipDuration"]
OVERLAP   = meta_data["Overlap"]
NFFT      = meta_data["Num_fft"]
FFTWIN    = meta_data["FFTWindow"]

stft_frames = spectral_audio_dataset["Data"]
stft_frames = np.reshape(stft_frames, newshape=(-1, np.prod(stft_frames.shape[1:])))
print(stft_frames.shape)

(139438, 1026)


####  Extract training and validation data

In [3]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler

tsize = 0.1
state = 1338

scaler = StandardScaler() #MinMaxScaler(feature_range=(-1, 1))
stft_frames_scaled = scaler.fit_transform(stft_frames)
print(stft_frames_scaled.min(), stft_frames_scaled.max())

X_train, X_test = train_test_split(
    stft_frames_scaled, test_size=tsize, random_state=state
)
print(X_train.shape, X_test.shape)

# Get the magnitude spectrograms
stft_train = tf.convert_to_tensor(
    X_train, dtype=K.floatx()
)
stft_test = tf.convert_to_tensor(
    X_test,  dtype=K.floatx() 
)
print(stft_train.shape, stft_test.shape)

-108.99208378847092 106.08346116204625
(125494, 1026) (13944, 1026)
(125494, 1026) (13944, 1026)


In [4]:
class ConvBlock(Layer):
    
    def __init__(self, maps, kernel, alpha=1, *args, **kwargs):
        super(ConvBlock, self).__init__(*args, **kwargs)
        self.conv = Conv1D(maps, kernel_size=kernel, strides=2, padding="causal")
        self.acti = Snake(alpha, trainable=True)
    
    def call(self, inputs):
        x = self.conv(inputs)
        return self.acti(x)        

class TransposeConvBlock(Layer):
    
    def __init__(self, maps, kernel, alpha=1, *args, **kwargs):
        super(TransposeConvBlock, self).__init__(*args, **kwargs)
        self.conv = Conv1DTranspose(maps, kernel_size=kernel, 
                strides=2, padding="same")
        self.acti = Snake(alpha, trainable=True)
        
    def call(self, inputs):
        x = self.conv(inputs)
        return self.acti(x)         
        
def make_encoder(input_dim, latent_dim):
    # Setup the NN inputs
    input_shape   = (np.prod(input_dim), )
    encoder_input = Input(shape=input_shape, name="encoder_input")    
    encoder_reshaped = Reshape(input_dim)(encoder_input)
    
    # First Conv1D will produce a 12-channel output
    conv_0 = Conv1D(4, kernel_size=5, strides=2, padding="valid")(encoder_reshaped)
    conv_0 = Snake()(conv_0)
    
    # Pass through a series of convolutions
    conv_1 = ConvBlock( 8, kernel=3)(conv_0)
    conv_2 = ConvBlock(12, kernel=3)(conv_1) #Conv1D(36, kernel_size=3, strides=2, padding="causal", activation="relu")(conv_1)
    conv_3 = ConvBlock(16, kernel=3)(conv_2)
    conv_4 = ConvBlock(20, kernel=3)(conv_3) #Conv1D(60, kernel_size=3, strides=2, padding="causal", activation="relu")(conv_3)
    conv_5 = ConvBlock(24, kernel=3)(conv_4)
    conv_6 = ConvBlock(28, kernel=3)(conv_5) #Conv1D(84, kernel_size=3, strides=2, padding="causal", activation="relu")(conv_5)
    
    # Flatten the data
    flat_0 = Flatten()(conv_6)
    
    # Prepare the prior distribution q(z|x)
    encoder_dense = Dense(
        2*latent_dim, name="encoder",
        kernel_initializer="zeros", bias_initializer='zeros'
    )(flat_0)
    
    return Model(inputs=[encoder_input], outputs=[encoder_dense], name="encoder")

def make_decoder(input_dim, latent_dim):
    # Setup the NN Input
    decoder_input = Input(shape=(latent_dim, ), name="decoder_input")
    
    # Use a dense layer to restructure the data
    decoder_reshaped = Dense(112)(decoder_input)
    decoder_reshaped = Snake()(decoder_reshaped)
    decoder_reshaped = Reshape((4, 28))(decoder_reshaped)
    
    # Pass through a series of convolutions
    conv_0 = TransposeConvBlock(24, kernel=3)(decoder_reshaped)
    conv_1 = TransposeConvBlock(20, kernel=3)(conv_0)
    conv_2 = TransposeConvBlock(16, kernel=3)(conv_1)
    conv_3 = TransposeConvBlock(12, kernel=3)(conv_2)
    conv_4 = TransposeConvBlock(8, kernel=3)(conv_3)
    conv_5 = TransposeConvBlock(4, kernel=3)(conv_4)
    conv_6 = TransposeConvBlock(2, kernel=3)(conv_5)
    
    flat_0 = Flatten()(conv_6)
    
    input_shape = np.prod(input_dim)
    decoder_linear_0 = Dense(2*input_shape, name="decoder_out")(flat_0)
    return Model(inputs=[decoder_input], outputs=[decoder_linear_0], name="decoder")

In [None]:
enc = make_encoder((513, 2), 64)
enc.compile()
enc.summary()

dec = make_decoder((513, 2), 64)
dec.compile()
dec.summary()

In [5]:
input_dim  = (513, 2)
latent_dim = 256
base_lr    = 4e-4

icp_model = GaussianBetaVAE(5, input_dim, latent_dim, make_encoder, make_decoder)

opt = optimizers.Adam(lr=base_lr)
icp_model.custom_compile(optimizer=opt)

In [None]:
icp_model.summary()

In [6]:
icp_model.fit(
    x=stft_train,
    y=stft_train,
    shuffle=True,
    epochs=5,
    batch_size=128,
    validation_data=(stft_test, stft_test)
)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x7fe2cc65f750>

#### Inspect the latent distribution(s) 

In [None]:
def sample_mvn(x_mu, x_logvar):
    x_std = np.exp(x_logvar)
    batch = x_std.shape[0]
    
    mvn = tfp.distributions.MultivariateNormalDiag(
        loc=x_mu, scale_diag=x_std
    )
    return mvn.sample(shape=[batch]).numpy()


test_stft_frames = stft_frames_scaled.copy() #spectral_audio_dataset["Data"]
test_stft_frames = np.reshape(test_stft_frames, newshape=(-1, 173, 1026))
print(test_stft_frames.shape)

k = np.random.randint(size=1, low=0, high=806)[0]
print(k)

test_frames = test_stft_frames[k, :, :]
print(test_frames.shape)

x_mu, x_logvar, _, _ = icp_model.predict(test_frames)

pred_stft = sample_mvn(x_mu, x_logvar)
pred_stft = scaler.inverse_transform(np.clip(pred_stft, -1, 1))
pred_stft = np.reshape(pred_stft, newshape=(513, 173, 2))
print(pred_stft.shape)

S = np.zeros(pred_stft.shape[:-1], dtype=np.complex)
S.real = pred_stft[:, :, 0]
S.imag = pred_stft[:, :, 1]
M = np.abs(S)
M_db = librosa.amplitude_to_db(M)

import librosa.display
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
img = librosa.display.specshow(M_db, sr=SR,
                               y_axis='log', x_axis='time', ax=ax)

y_ = librosa.istft(stft_matrix = S, 
                   hop_length  = OVERLAP, 
                   window      = FFTWIN)

import IPython.display as ipd
ipd.display(ipd.Audio(y_, rate=SR))

In [None]:
'''
timedur_ms  = 100.
timedur_ss  = timedur_ms / 1000.
timesteps   = int(np.ceil(timedur_ss * SR / OVERLAP))
overlap     = timesteps // 2

# Reshape inputs into original "spectrums"
ORGDIM_TIME = int(np.ceil(DUR * SR / OVERLAP))
ORGDIM_FREQ = normalized_power_spectrum.shape[-1]

ax0 = (0, 0)
ax1 = (timesteps - 1, 0)
ax2 = (0, 0)

power_spectrum = normalized_power_spectrum.reshape(
    (-1, ORGDIM_TIME, ORGDIM_FREQ)
)
power_spectrum = np.pad(
    power_spectrum, [ax0, ax1, ax2], mode='constant'
)

phase_spectrum = normalized_phase_spectrum.reshape(
    (-1, ORGDIM_TIME, ORGDIM_FREQ)
)
phase_spectrum = np.pad(
    phase_spectrum, [ax0, ax1, ax2], mode='constant'
)

print(power_spectrum.shape, phase_spectrum.shape)

p = phase_spectrum.shape[0]
q = ORGDIM_TIME
r = timesteps
j = ORGDIM_FREQ

power_spectrum_timesteps = np.zeros((p, q, r, j))
phase_spectrum_timesteps = np.zeros((p, q, r, j))

def time_slices_for(arr):
    niters = ORGDIM_TIME
    slices = [arr[:, i:i+timesteps] for i in range(niters)]
    return np.array(slices)

for ix, spectrums in enumerate(zip(power_spectrum, phase_spectrum)):
    stft_arrs = np.array(spectrums)
    stft_slices = time_slices_for(stft_arrs)
    power_spectrum_timesteps[ix] = stft_slices[:, 0, :, :]
    phase_spectrum_timesteps[ix] = stft_slices[:, 1, :, :]

power_spectrum_timesteps = power_spectrum_timesteps.reshape((-1, r, j))
phase_spectrum_timesteps = phase_spectrum_timesteps.reshape((-1, r, j))
print(power_spectrum_timesteps.shape, phase_spectrum_timesteps.shape)

#print(194 // (timesteps - 1))    
#print(ORGDIM_TIME / (timesteps - overlap))
'''