In [None]:
import numpy as np
import matplotlib.pyplot as plt
from nmf.numf.base import numf
from nmf.numf.peaks import get_neighbors
from nmf.numf.multigrid import create_R, get_fine_p

In [None]:
def gauss(x, sigma=1, mean=0, scale=1):
    """Evaluate an unnormalized Gaussian bump.

    Computes ``scale * exp(-(x - mean)**2 / (2 * sigma**2))`` elementwise, so the
    peak value at ``x == mean`` is ``scale`` (no 1/(sigma*sqrt(2*pi)) factor).

    Args:
        x: scalar or numpy array of evaluation points; the output has its shape.
        sigma: standard deviation (width) of the bump.
        mean: location of the peak.
        scale: peak height.

    Returns:
        Array (or scalar) with the same shape as ``x``.
    """
    squared_offset = np.square(x - mean)
    denominator = 2 * sigma ** 2
    return scale * np.exp(-(squared_offset / denominator))

m = 100  # number of rows of M (length of each source signal)
r = 3    # number of sources (columns of Wtrue / rows of Htrue)
p1 = 24  # center of the first Gaussian source
p2 = 50  # width of the box-shaped second source; also used later as its peak location
p3 = 76  # center of the third Gaussian source

x = np.linspace(1, m, m).reshape(-1, 1)
w1 = gauss(x, sigma=2, mean=p1)
# Box of ones, width p2, centered in the m rows. The right pad is computed as the
# remainder so w2 always has exactly m rows, even when m - p2 is odd.
pad_left = (m - p2) // 2
pad_right = m - p2 - pad_left
w2 = np.concatenate((np.zeros((pad_left, 1)), np.ones((p2, 1)), np.zeros((pad_right, 1))))
w3 = gauss(x, sigma=2, mean=p3)
Wtrue = np.hstack((w1, w2, w3))


n = 6  # number of mixtures (columns of M); must match the rows listed below
# NOTE(review): the choice c = 1/sqrt(r-1) nudged by e is not explained here;
# presumably it places H near a uniqueness boundary of the factorization — confirm.
c = 1/np.sqrt(r-1)
e = 0.001
# Each listed row (a column of Htrue after .T) mixes two sources with weights
# summing to 1.
Htrue = np.array([[c + e, 1-c-e, 0],
                  [1-c-e, c+e, 0],
                  [c+e, 0, 1-c-e],
                  [1-c-e, 0, c+e],
                  [0, c+e, 1-c-e],
                  [0, 1-c-e, c+e]]).T
assert Wtrue.shape == (m, r)
assert Htrue.shape == (r, n)

M = Wtrue @ Htrue

In [None]:
fig, ax = plt.subplots()
ax.plot(M)  # a 2-D array is drawn as one line per column of M
ax.set(title="Synthetic data M = Wtrue @ Htrue", xlabel="row index", ylabel="value");

### No Multigrid

In [None]:
# Number of iterations passed to every numf call in this notebook.
iters = 200

In [None]:
# Fixed seed so the shared random initialization is identical on every
# Restart & Run All. NOTE(review): any randomness inside numf itself is not
# controlled by this generator — confirm whether numf needs its own seed.
SEED = 0
rng = np.random.default_rng(SEED)
W0 = rng.random((m, r))  # uniform [0, 1), same distribution as np.random.rand
H0 = rng.random((r, n))

In [None]:
%%time
_, _, _ = numf(M, W0.copy(), H0.copy(), iters=iters)

### Multigrid 1-Level

In [None]:
%%time
R1 = create_R(m)
_, H1, pouts1 = numf(R1 @ M, R1 @ W0.copy(), H0.copy(), iters=iters)
fine_pouts1 = get_fine_p(pouts1)
_, _, _ = numf(M, W0.copy(), H1, pvals=get_neighbors(fine_pouts1, m, 3), iters=iters)

### Multigrid 2-Level

In [None]:
%%time
R1 = create_R(m)
R2 = create_R(R1.shape[0])
_, H2, pouts2 = numf(R2 @ R1 @ M, R2 @ R1 @ W0.copy(), H0.copy(), iters=iters)
fine_pouts2 = get_fine_p(pouts2, scaling_factor=4)
_, _, _ = numf(M, W0.copy(), H2, pvals=get_neighbors(fine_pouts2, m, 3), iters=iters)

### Regularization Test

In [None]:
# Candidate peak positions built around the ground-truth locations p1, p2, p3.
# NOTE(review): presumably a radius of 5 samples — confirm get_neighbors.
pvals = get_neighbors([p1, p2, p3], m, 5)

# l2 = 0. Results get their own names: writing them back into W0/H0 would
# clobber the shared random initialization used by the l2 = 0.3 and l2 = 0.6
# cells (they would start from this solution instead) and would make this
# cell non-idempotent.
W_l2_0, H_l2_0, pouts0 = numf(M, W0.copy(), H0.copy(), pvals=pvals, iters=iters, l2=0)
plt.plot(W_l2_0)
plt.title("Recovered W, l2 = 0");

In [None]:
# l2 = 0.3: same M, initial-guess expression, pvals and iters as the l2 = 0 run;
# only the regularization weight changes.
W1, H1, pouts1 = numf(M, W0.copy(), H0.copy(), pvals=pvals, iters=iters, l2=0.3)
plt.plot(W1)
plt.title("Recovered W, l2 = 0.3");

In [None]:
# l2 = 0.6: same M, initial-guess expression, pvals and iters as the l2 = 0 run;
# only the regularization weight changes.
W2, H2, pouts2 = numf(M, W0.copy(), H0.copy(), pvals=pvals, iters=iters, l2=0.6)
plt.plot(W2)
plt.title("Recovered W, l2 = 0.6");