In [None]:
import librosa
import scipy.signal as scipy_signal
import numpy as np
import tensorflow as tf
from hmica_learning import HMICALearner, ParallelHMICALearner
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import Audio
import pandas as pd
import os

In [None]:
# Locate one utterance of the 2-speaker 8 kHz dataset: the mixture ("mix") plus the
# reference sources ("s1", "s2") it was presumably built from.
parents = [os.getcwd(), 'mix', '2_speaker_8000_hz']
split_dir = os.path.join(*parents, '2speakers', 'wav8k', 'max', 'cv')
utterance = "01aa010s_0.40678_40ia010v_-0.40678.wav"

path = os.path.join(split_dir, 'mix', utterance)
path_a = os.path.join(split_dir, 's1', utterance)
path_b = os.path.join(split_dir, 's2', utterance)

# sr=None keeps each file's native sampling rate instead of librosa's default resampling.
mixed_signal, sr = librosa.load(path, sr=None)
source_1, sampling_rate_1 = librosa.load(path_a, sr=None)
source_2, sampling_rate_2 = librosa.load(path_b, sr=None)

# Later cells mix and play all three signals at `sr`, so the rates must agree.
assert sampling_rate_1 == sampling_rate_2 == sr, (sampling_rate_1, sampling_rate_2, sr)

In [None]:
def butter_lowpass(cutoff, fs, order=5):
    """
    Design a digital Butterworth low-pass filter in transfer-function form.

    :param cutoff: The cutoff frequency in Hz.
    :param fs: The sampling rate in Hz.
    :param order: The filter order.
    :return: Tuple ``(b, a)`` of numerator and denominator polynomial coefficients.
    """
    nyquist = fs * 0.5
    # scipy expects the cutoff normalized to the Nyquist frequency (range 0..1).
    coeff_b, coeff_a = scipy_signal.butter(order, cutoff / nyquist, btype='low', analog=False)
    return coeff_b, coeff_a

def lowpass_filter(data, cutoff, fs, order=55):
    """
    Apply a causal Butterworth low-pass filter to the audio data.

    The filter is designed and applied as cascaded second-order sections (SOS).
    The transfer-function ``(b, a)`` form is numerically unstable at high orders,
    and the default order here is 55. With ``lfilter(b, a, ...)`` the output could
    diverge or turn into NaNs; the SOS cascade stays stable.

    NOTE(review): the default ``order=55`` is far above the ``order=5`` default of
    the sibling filter helpers; it looks like a typo for 5. Confirm before relying on it.

    :param data: The audio data as a NumPy array (filtered along the last axis).
    :param cutoff: The cutoff frequency of the filter in Hz.
    :param fs: The sampling rate of the audio data in Hz.
    :param order: The order of the filter.
    :return: The filtered audio data. The filter is causal, so the output is
        phase-shifted relative to ``data``.
    """
    nyq = 0.5 * fs
    normal_cutoff = cutoff / nyq
    sos = scipy_signal.butter(order, normal_cutoff, btype='low', analog=False, output='sos')
    return scipy_signal.sosfilt(sos, data)

def high_pass_filter(data, cutoff_freq, fs, order=5):
    """
    Apply a zero-phase Butterworth high-pass filter to the audio data.

    The filter is designed as second-order sections (SOS) rather than
    transfer-function ``(b, a)`` coefficients. The ``(b, a)`` form loses accuracy
    quickly as ``order`` grows or the normalized cutoff gets small, as it does for
    low cutoffs at audio sampling rates. The signal is filtered forward and then
    backward, so there is no phase shift and the magnitude response is squared.

    :param data: The audio data as a NumPy array (filtered along the last axis).
    :param cutoff_freq: The cutoff frequency of the filter in Hz.
    :param fs: The sampling rate of the audio data in Hz.
    :param order: The order of the filter.
    :return: The filtered audio data, same length as ``data``.
    :raises ValueError: from scipy, if ``data`` is not longer than the padding
        that forward-backward filtering needs at each edge.
    """
    nyq = 0.5 * fs
    normal_cutoff = cutoff_freq / nyq
    sos = scipy_signal.butter(order, normal_cutoff, btype='high', analog=False, output='sos')
    y = scipy_signal.sosfiltfilt(sos, data)
    return y

In [None]:
fig, ax = plt.subplots()
ax.plot(source_1)
ax.set(title="Source 1 (s1 reference)", xlabel="Sample", ylabel="Amplitude");

In [None]:
fig, ax = plt.subplots()
ax.plot(source_2)
ax.set(title="Source 2 (s2 reference)", xlabel="Sample", ylabel="Amplitude");

In [None]:
fig, ax = plt.subplots()
ax.plot(mixed_signal)
ax.set(title="Dataset mixture", xlabel="Sample", ylabel="Amplitude");

In [None]:
# The two sources are stacked row-wise before mixing, so they must have equal length.
assert source_1.shape == source_2.shape, (source_1.shape, source_2.shape)
source_1.shape, source_2.shape, mixed_signal.shape

### Optional filtering: low-pass removes high-frequency noise, high-pass removes low-frequency noise (the high-pass step below is currently disabled)

In [None]:
# Optional high-pass pre-filtering to remove low-frequency noise. Disabled by default.
# NOTE: enabling it overwrites the loaded signals, so re-running this cell with the
# flag on filters them again. Re-run the loading cell first.
APPLY_HIGH_PASS = False
HIGH_PASS_CUTOFF_HZ = 300

if APPLY_HIGH_PASS:
    source_1 = high_pass_filter(source_1, HIGH_PASS_CUTOFF_HZ, sampling_rate_1)
    source_2 = high_pass_filter(source_2, HIGH_PASS_CUTOFF_HZ, sampling_rate_2)
    mixed_signal = high_pass_filter(mixed_signal, HIGH_PASS_CUTOFF_HZ, sr)

In [None]:
fig, ax = plt.subplots()
ax.plot(source_1)
ax.set(title="Source 1 after optional filtering", xlabel="Sample", ylabel="Amplitude");

In [None]:
fig, ax = plt.subplots()
ax.plot(source_2)
ax.set(title="Source 2 after optional filtering", xlabel="Sample", ylabel="Amplitude");

In [None]:
fig, ax = plt.subplots()
ax.plot(mixed_signal)
ax.set(title="Dataset mixture after optional filtering", xlabel="Sample", ylabel="Amplitude");

In [None]:
# Listen to the dataset mixture.
Audio(data=mixed_signal, rate=sr)

In [None]:
# Listen to reference source 1.
Audio(data=source_1, rate=sampling_rate_1)

In [None]:
# Listen to reference source 2.
Audio(data=source_2, rate=sampling_rate_2)

In [None]:
# Build a synthetic 2-channel instantaneous mixture x = M @ s from the two reference
# sources, so the true sources and mixing matrix are known when judging separation.
mixing_matrix = np.array([[0.3, 0.7],
                          [0.8, 0.2]])
true_sources = np.vstack((source_1, source_2))  # (2, n_samples); needs equal-length sources
x = mixing_matrix @ true_sources                # (2, n_samples): one row per observed channel
Audio(x, rate=sr)  # a 2-D array is played as a multi-channel clip

In [None]:
# Sampling rates of the two sources and the mixture (everything is played and mixed at `sr`).
sampling_rate_1, sampling_rate_2, sr

### STFT and mel-spectrogram round trips (optional, can be skipped; not used by HMICA below)

In [None]:
# Short-time Fourier transform of the dataset mixture (exploratory only).
nperseg = 128  # samples per segment; reused by the inverse STFT cell below
f, t, Zxx = scipy_signal.stft(
    mixed_signal,
    fs=sr,
    nperseg=nperseg,
    noverlap=None,  # None -> scipy default of nperseg // 2
)
print(f'{f.shape=}, {t.shape=}, {Zxx.shape=}')

In [None]:
# Convert the mixture waveform to a mel spectrogram with librosa's default settings.
spec = librosa.feature.melspectrogram(sr=sr, y=mixed_signal)

In [None]:
# Invert the mel spectrogram back to a waveform. The phase has to be estimated, so this is lossy.
res = librosa.feature.inverse.mel_to_audio(spec, sr=sr)

In [None]:
# Listen to the mel-spectrogram reconstruction.
Audio(data=res, rate=sr)

In [None]:
# The original mixture again, for A/B comparison with the reconstruction.
Audio(data=mixed_signal, rate=sr)

In [None]:
# Mel inversion estimates the phase (Griffin-Lim), so comparing individual samples
# says nothing about quality. Compare lengths and overall RMS level instead.
res.shape, mixed_signal.shape, float(np.sqrt(np.mean(res ** 2))), float(np.sqrt(np.mean(mixed_signal ** 2)))

In [None]:

# Invert the STFT. Zxx must stay complex because it carries the phase; passing only
# Zxx.real (as before) produced a distorted reconstruction.
t2, x_hat = scipy_signal.istft(Zxx, fs=sr, nperseg=nperseg, noverlap=None)
print(f'{t2.shape=}, {x_hat.shape=}')
# istft can return a few padded samples past the input. With the default Hann window
# and 50% overlap the round trip reproduces the original over its span.
assert np.allclose(x_hat[:mixed_signal.size], mixed_signal, atol=1e-4)

In [None]:
# Listen to the inverse-STFT reconstruction.
Audio(data=x_hat, rate=sr)

In [None]:
fig, ax = plt.subplots()
ax.plot(x_hat)
ax.set(title="Inverse STFT of the mixture", xlabel="Sample", ylabel="Amplitude");

### HMICA: hidden Markov ICA separation of the synthetic mixture `x`

In [None]:
# Number of observed channels, i.e. rows of the synthetic mixture x.
len(x)

In [None]:
def sticky_transition_matrix(n_states, self_prob=0.65):
    """Return an (n_states, n_states) row-stochastic matrix with `self_prob` on the
    diagonal and the remaining mass split evenly over the other states."""
    off_diag = (1.0 - self_prob) / (n_states - 1)
    A = np.full((n_states, n_states), off_diag)
    np.fill_diagonal(A, self_prob)
    return A


A_6_states = sticky_transition_matrix(6)  # off-diagonal 0.07
# Previously typed by hand with 0.1167 off-diagonal entries, so rows summed to 1.0001.
A_4_states = sticky_transition_matrix(4)

# NOTE(review): presumably the learner initializes its parameters randomly. Seeding
# makes re-runs comparable; confirm which RNG HMICALearner actually uses.
SEED = 0
np.random.seed(SEED)
tf.random.set_seed(SEED)

hmica = ParallelHMICALearner(
    # Number of HMM states, tied to the transition matrix so the two cannot disagree.
    # Original intent: no-one, speaker 1, speaker 2, both speakers, "doubled" to capture
    # more acoustic detail. NOTE(review): 2 x 4 would be 8, not 6; confirm the intended k.
    k=len(A_6_states),
    m=2,  # number of sources
    x_dims=x.shape[0],  # number of observed channels (rows of x)
    use_gar=True,
    # Generalized autoregressive order p. At this dataset's 8 kHz rate, p = 8-12
    # samples covers 1-1.5 ms of speech context.
    gar_order=10,
    update_interval=10,  # suggested by the paper
    learning_rates={
        'W': 1e-2,  # Unmixing matrix
        'R': 1e-2,  # Shape parameter
        'beta': 1e-2,  # Scale parameter
        'C': 1e-2  # GAR coefficients
    },
    use_analytical=False,
    # float32 matches the default dtype the original nested-list literal converted to.
    A=tf.convert_to_tensor(A_6_states, dtype=tf.float32)
)

x_obs = tf.convert_to_tensor(x.astype('float32').T)  # (n_samples, n_channels)
history = hmica.train(x_obs, hmm_max_iter=15, ica_max_iter=5000, hmm_tol=1e-4, ica_tol=1e-2)

In [None]:
# Log likelihood of the full HMICA model after each outer (HMM) iteration.
hmm_ll = history['hmm_ll']
fig, ax = plt.subplots()
sns.lineplot(x=range(len(hmm_ll)), y=hmm_ll, ax=ax)
ax.set(title="Full Model Log Likelihood", xlabel="Iteration", ylabel="Log Likelihood");

In [None]:
# Reshape the ICA log-likelihood traces into long format for seaborn: one row per
# (group, iteration). Groups are the keys of history['ica_ll']; presumably one per
# HMM state (confirm).
df_list = [
    pd.DataFrame({'Group': group, 'Value': values, 'Index': range(len(values))})
    for group, values in history['ica_ll'].items()
]
# pd.concat raises on an empty list, so fall back to an empty frame with the same columns.
df = (pd.concat(df_list, ignore_index=True) if df_list
      else pd.DataFrame(columns=['Group', 'Value', 'Index']))

fig, ax = plt.subplots()
sns.lineplot(data=df, x='Index', y='Value', hue='Group', ax=ax)
ax.set(title="ICA (X|Z) Log Likelihood", xlabel="Iteration", ylabel="Log Likelihood");

In [None]:
def obs_prob(state, x):
    """Per-state likelihood callback for Viterbi: delegates to the ICA model.

    Note the argument order is swapped when forwarding to compute_likelihood.
    """
    return hmica.ica.compute_likelihood(x, state)


state_sequence, w_matrix = hmica.hmm.viterbi_for_inference(
    tf.convert_to_tensor(x.astype('float32').T),  # (n_samples, n_channels)
    obs_prob,
)

In [None]:
# Unmixed source estimates under each HMM state, stacked along a new leading state axis.
observations = tf.convert_to_tensor(x.astype('float32').T)
sources_k = [hmica.ica.get_sources(observations, state_idx) for state_idx in range(hmica.k)]
sources_stacked = tf.stack(sources_k)

In [None]:
# Weight each state's source estimates by that time step's row of state_sequence and
# sum over states. If the rows are one-hot, this just selects the decoded state's
# sources. Then transpose so the sources come first (presumably (n_sources, n_samples)).
# NOTE(review): the last row of state_sequence is dropped; presumably Viterbi returns
# one more row than there are samples. Confirm.
reconstructed_sources = tf.transpose(
    tf.einsum('tk,ktn->tn', state_sequence[:-1], sources_stacked)
)

In [None]:
# Listen to the first separated source.
Audio(data=reconstructed_sources[0].numpy(), rate=sr)

In [1]:
# Listen to the second separated source.
Audio(data=reconstructed_sources[1].numpy(), rate=sr)

NameError: name 'Audio' is not defined

In [None]:
# Column totals of state_sequence; these are time steps per state if the rows are one-hot.
tf.math.reduce_sum(state_sequence, axis=0)

In [None]:
# Inspect the learned ICA unmixing parameter W (per state or shared: confirm in HMICALearner).
hmica.ica.W