# PS4-4 ICA

### (a) Gaussian source.

When the sources are drawn from a standard Gaussian distribution, the notes give the log-likelihood (with $w_j^T$ denoting the $j$-th row of $W$)

\begin{equation}\notag
l(W) = \sum_{i=1}^n\left(\log|W|+\sum_{j=1}^d\log\frac{1}{\sqrt{2\pi}}\exp\left(-\frac{\left(w_j^Tx^{(i)}\right)^2}{2}\right)\right).
\end{equation}

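To maximize the likelihood, take the gradient with respect to $W$. The constant $\log\frac{1}{\sqrt{2\pi}}$ terms vanish, and $X\in\mathbb{R}^{n\times d}$ denotes the design matrix whose $i$-th row is $\left(x^{(i)}\right)^T$.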

\begin{align*}
\nabla_Wl(W)&=n\left(W^{-1}\right)^T-\frac{1}{2}\sum_{i=1}^n\nabla_W\sum_{j=1}^d\left(w_j^Tx^{(i)}\right)^2\\
&=n\left(W^{-1}\right)^T-\frac{1}{2}\sum_{i=1}^n2Wx^{(i)}\left(x^{(i)}\right)^T\\
&=n\left(W^{-1}\right)^T-W\sum_{i=1}^nx^{(i)}\left(x^{(i)}\right)^T\\
&=n\left(W^{-1}\right)^T-WX^TX
\end{align*}

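Setting $\nabla_Wl(W_0)=0$ gives $n\left(W_0^{-1}\right)^T=W_0X^TX$. Left-multiplying by $W_0^T$ and right-multiplying by $(X^TX)^{-1}$, we obtain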

\begin{align*}
W_0^TW_0=n(X^TX)^{-1}
\end{align*}

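This condition determines only $W^TW$, so $W$ is not unique. Every matrix of the form

$$W=OW_0,$$

where $O$ is an orthogonal matrix, satisfies it as well, since $W^TW=W_0^TO^TOW_0=W_0^TW_0$. Gaussian sources can therefore be recovered only up to an arbitrary rotation or reflection.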
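
The ambiguity can be checked numerically. The sketch below is only an illustration under stated assumptions: the data are synthetic (not the problem-set audio), and the helper names `grad_gaussian`, `log_lik_gaussian`, and `rel_grad_norm` are hypothetical. It builds one solution $W_0$ from a Cholesky factor of $(X^TX)^{-1}$, rotates it by a random orthogonal $O$, and compares the gradient and the log-likelihood at both points.

```python
import numpy as np

def grad_gaussian(W, X):
    """Gradient from part (a): n * W^{-T} - W X^T X."""
    n = X.shape[0]
    return n * np.linalg.inv(W).T - W @ (X.T @ X)

def log_lik_gaussian(W, X):
    """Gaussian-source log-likelihood l(W) from part (a)."""
    n, d = X.shape
    _, logabsdet = np.linalg.slogdet(W)
    S = X @ W.T  # row i holds W x^{(i)}
    return n * logabsdet - 0.5 * n * d * np.log(2 * np.pi) - 0.5 * np.sum(S ** 2)

# Synthetic data (an assumption for illustration only).
rng = np.random.default_rng(0)
n, d = 500, 3
A = np.array([[1.0, 0.5, 0.0],
              [0.2, 1.0, 0.3],
              [0.0, 0.4, 1.0]])  # fixed, well-conditioned mixing matrix
X = rng.normal(size=(n, d)) @ A.T

# One solution of W0^T W0 = n (X^T X)^{-1}, built from a Cholesky factor.
L = np.linalg.cholesky(np.linalg.inv(X.T @ X))  # L @ L.T == (X^T X)^{-1}
W0 = np.sqrt(n) * L.T

# A random orthogonal matrix O from a QR decomposition.
O, _ = np.linalg.qr(rng.normal(size=(d, d)))
W1 = O @ W0

def rel_grad_norm(W):
    """Gradient norm relative to the size of the n * W^{-T} term."""
    return np.linalg.norm(grad_gaussian(W, X)) / np.linalg.norm(n * np.linalg.inv(W).T)

print(rel_grad_norm(W0), rel_grad_norm(W1))                # both near machine precision
print(log_lik_gaussian(W0, X), log_lik_gaussian(W1, X))    # equal up to rounding
```

Both matrices are stationary points with the same likelihood, so the data alone cannot distinguish $W_0$ from $OW_0$.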

### (b) Laplace source.

When the sources are drawn from a standard Laplace distribution, the log-likelihood of a single sample $x^{(i)}$ becomes
\begin{align*}
l(W) = \log|W|+\sum_{j=1}^d\log\frac{1}{2}\exp\left(-|w_j^Tx^{(i)}|\right).
\end{align*}

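Its gradient with respect to $W$ follows from $\frac{d}{du}|u|=\mathrm{sign}(u)$ for $u\neq0$, with $\mathrm{sign}(\cdot)$ applied elementwise.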

\begin{align*}
\nabla_Wl(W)&=\left(W^{-1}\right)^T-\nabla_W\sum_{j=1}^d|w_j^Tx^{(i)}|\\
&=\left(W^{-1}\right)^T-\mathrm{sign}\left(Wx^{(i)}\right)\left(x^{(i)}\right)^T
\end{align*}

The stochastic gradient ascent (SGA) update rule with learning rate $\alpha$ is therefore

$$W:=W+\alpha\left(\left(W^{-1}\right)^T-\mathrm{sign}\left(Wx^{(i)}\right)\left(x^{(i)}\right)^T\right).$$
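
As a sanity check on the derivation, the analytic gradient can be compared with a central finite-difference approximation of the single-sample log-likelihood. This is a minimal sketch: the function names (`log_lik_laplace`, `grad_laplace`, `numeric_grad`) and the test values of `W` and `x` are assumptions chosen for illustration, with every component of $Wx$ kept away from zero so that the kink in $|\cdot|$ is avoided.

```python
import numpy as np

def log_lik_laplace(W, x):
    """Single-sample log-likelihood: log|det W| + sum_j log(0.5 * exp(-|w_j^T x|))."""
    _, logabsdet = np.linalg.slogdet(W)
    return logabsdet + np.sum(np.log(0.5) - np.abs(W @ x))

def grad_laplace(W, x):
    """Analytic gradient from part (b): W^{-T} - sign(W x) x^T."""
    return np.linalg.inv(W).T - np.outer(np.sign(W @ x), x)

def numeric_grad(f, W, eps=1e-6):
    """Central finite differences of a scalar function f, one entry of W at a time."""
    G = np.zeros_like(W)
    for idx in np.ndindex(*W.shape):
        E = np.zeros_like(W)
        E[idx] = eps
        G[idx] = (f(W + E) - f(W - E)) / (2 * eps)
    return G

# Illustrative test point (an assumption): a small perturbation of the identity.
d = 4
W = np.eye(d) + 0.1 * np.arange(16.0).reshape(d, d) / 16.0
x = np.array([1.0, -2.0, 0.5, 1.5])

analytic = grad_laplace(W, x)
numeric = numeric_grad(lambda V: log_lik_laplace(V, x), W)
print(np.max(np.abs(analytic - numeric)))  # should be tiny if the derivation is right
```

Agreement between the two gradients supports the update rule above before it is used on the audio data.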

### (c) Implementation.

In [None]:
import numpy as np
import scipy.io.wavfile
import os

def update_W(W, x, learning_rate):
    """
    Perform a gradient ascent update on W using data element x and the provided learning rate.

    This function should return the updated W.

    Use the Laplace distribution in this problem.

    Args:
        W: The W matrix for ICA
        x: A single data element
        learning_rate: The learning rate to use

    Returns:
        The updated W
    """
    
    # *** START CODE HERE ***
    # SGA step from part (b): W + alpha * (W^{-T} - sign(W x) x^T)
    updated_W = W + learning_rate * (np.linalg.inv(W).T - np.outer(np.sign(W @ x), x))
    # *** END CODE HERE ***

    return updated_W


def unmix(X, W):
    """
    Unmix an X matrix according to W using ICA.

    Args:
        X: The data matrix
        W: The W for ICA

    Returns:
        A numpy array S containing the split data
    """

    S = np.zeros(X.shape)


    # *** START CODE HERE ***
    # Rows of X are samples, so stacking s^{(i)} = W x^{(i)} as rows gives X W^T
    S = X @ W.T
    # *** END CODE HERE ***

    return S


Fs = 11025

def normalize(dat):
    return 0.99 * dat / np.max(np.abs(dat))

def load_data():
    mix = np.loadtxt('./data/mix.dat')
    return mix

def save_W(W):
    np.savetxt('output/W.txt',W)

def save_sound(audio, name):
    scipy.io.wavfile.write('output/{}.wav'.format(name), Fs, audio)

def unmixer(X):
    M, N = X.shape
    W = np.eye(N)

    anneal = [0.1 , 0.1, 0.1, 0.05, 0.05, 0.05, 0.02, 0.02, 0.01 , 0.01, 0.005, 0.005, 0.002, 0.002, 0.001, 0.001]
    print('Separating tracks ...')
    for lr in anneal:
        print(lr)
        rand = np.random.permutation(range(M))
        for i in rand:
            x = X[i]
            W = update_W(W, x, lr)

    return W

In [32]:
# Seed the randomness of the simulation so this outputs the same thing each time
np.random.seed(0)
X = normalize(load_data())

print(X.shape)

for i in range(X.shape[1]):
    save_sound(X[:, i], 'mixed_{}'.format(i))

W = unmixer(X)
print(W)
save_W(W)
S = normalize(unmix(X, W))
assert S.shape[1] == 5
for i in range(S.shape[1]):
    # Remove any stale file at the path that save_sound actually writes to
    path = 'output/split_{}.wav'.format(i)
    if os.path.exists(path):
        os.unlink(path)
    save_sound(S[:, i], 'split_{}'.format(i))

(53442, 5)
Separating tracks ...
0.1
0.1
0.1
0.05
0.05
0.05
0.02
0.02
0.01
0.01
0.005
0.005
0.002
0.002
0.001
0.001
[[ 52.83337868  16.79535173  19.94059268 -10.1982649  -20.8969462 ]
 [ -9.93333916  -0.97878167  -4.67969942   8.0443386    1.78975024]
 [  8.31132067  -7.47665691  19.31501843  15.17431965 -14.32607253]
 [-14.66728796 -26.64498039   2.44086049  21.38241345  -8.42100435]
 [ -0.26910696  18.37442886   9.31246685   9.10279432  30.59440916]]
