## Background
* In the paper https://arxiv.org/pdf/1711.02037.pdf, after computing the approximation $A \approx QB$, the authors run HALS in a small subspace using a rather complicated procedure.


* It seems that we can instead plug $A \approx QB$ directly into the HALS update formulas (equations (14) and (15) in the paper) and still avoid the $O(np)$ computations that involve the full matrix $A$.


* I wrote a naive implementation of this idea and compared it with the authors' algorithms. Below, their HALS algorithm is called `nmf`, and their randomized + HALS algorithm is called `rnmf`. My naive implementation is called `rnmf2`. (Note that they use `_update_cdnmf_fast` from sklearn for HALS, while I use my own naive function.)

* The code is at https://github.com/zihao12/ristretto/blob/master/ristretto/nmf.py
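The idea in the bullets above can be sketched as follows. This is a minimal illustration, not the code in the linked repository: the function names `qb_decomposition`, `hals_sweep`, and `rnmf_direct` are hypothetical, the random range finder and its parameters are assumptions, and $A$ is taken to be $p \times n$ as in the cells below. The only point it shows is that $A H^T \approx Q (B H^T)$ and $A^T W \approx B^T (Q^T W)$, so the HALS sweeps never multiply by the full $A$ after the initial $QB$ step.

```python
import numpy as np

def qb_decomposition(A, rank, oversample=20, n_subspace=2, seed=0):
    """Randomized range finder (assumed variant): Q has orthonormal columns, B = Q^T A."""
    rng = np.random.default_rng(seed)
    l = rank + oversample
    Y = A @ rng.standard_normal((A.shape[1], l))
    for _ in range(n_subspace):
        # Subspace iterations with re-orthonormalization for stability.
        Q, _ = np.linalg.qr(Y)
        Z, _ = np.linalg.qr(A.T @ Q)
        Y = A @ Z
    Q, _ = np.linalg.qr(Y)
    return Q, Q.T @ A

def hals_sweep(AHt, HHt, W, eps=1e-12):
    """One HALS pass over the columns of W, given AHt ~ A H^T and HHt = H H^T."""
    for k in range(W.shape[1]):
        grad = AHt[:, k] - W @ HHt[:, k]
        W[:, k] = np.maximum(0.0, W[:, k] + grad / max(HHt[k, k], eps))
    return W

def rnmf_direct(A, rank, oversample=20, maxiter=50, seed=0):
    """HALS where every product with A is replaced by a product with Q and B."""
    rng = np.random.default_rng(seed)
    p, n = A.shape
    Q, B = qb_decomposition(A, rank, oversample, seed=seed)
    W = rng.uniform(size=(p, rank))
    H = rng.uniform(size=(rank, n))
    for _ in range(maxiter):
        # A H^T is approximated by Q (B H^T): cost O((p + n) * l * rank), not O(p * n * rank).
        W = hals_sweep(Q @ (B @ H.T), H @ H.T, W)
        # A^T W is approximated by B^T (Q^T W); the same sweep is applied to H^T.
        H = hals_sweep(B.T @ (Q.T @ W), W.T @ W, H.T).T
    return W, H
```

With this substitution each sweep is exact coordinate descent on $\|QB - WH\|_F^2$ rather than on $\|A - WH\|_F^2$, so any error in $A \approx QB$ carries over to the fitted factors.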

## Result:
* Speed is as expected. Although `rnmf2` is slightly slower than `rnmf` (about 11.4 s versus 8.5 s in the run below), probably because of implementation details, both are much faster than `nmf` (about 120 s).

* It is a bit worrying that `rnmf2` lags behind in loss. I have tried data of different sizes, and `rnmf2` always ends with a higher loss than `rnmf`, which in turn ends with a higher loss than `nmf`.
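To put the loss gap in perspective, the relative differences can be computed from the losses printed in the cells below. This is only arithmetic on those reported numbers; the helper name `relative_gap` is made up for illustration. The gap to `nmf` comes out at roughly 0.0014% for `rnmf` and roughly 0.025% for `rnmf2`.

```python
# Losses copied from the printed output of the cells in this notebook.
losses = {
    "nmf": 738505818.5604401,
    "rnmf": 738516117.0531319,
    "rnmf2": 738689006.4422762,
}

def relative_gap(loss, reference):
    """Relative excess of `loss` over `reference` (hypothetical helper)."""
    return (loss - reference) / reference

# Compare each randomized variant against the full HALS run.
gaps = {name: relative_gap(value, losses["nmf"]) for name, value in losses.items()}
for name, gap in gaps.items():
    print("{}: {:.4%} above nmf".format(name, gap))
```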

## Generate data and compute oracle loss

In [1]:
import warnings
#warnings.filterwarnings("ignore",category =RuntimeWarning)
import numpy as np
from ristretto import nmf
from sklearn.decomposition import NMF
import time
np.random.seed(123)

## generate data
p = 10000
n = 5000
K = 5
W = np.exp(np.random.uniform(size = (p,K)))
H = np.exp(np.random.uniform(size = (K,n)))
Lam = W.dot(H)
A = np.random.poisson(Lam, size = (p,n))

# compute oracle loss
cost_oracle = nmf.cost(A, W, H)
print("oracle loss: {}".format(cost_oracle))

oracle loss: 739609772.7034554


## nmf

In [2]:
print("nmf")
start = time.time()
(W_nmf,H_nmf) = nmf.compute_nmf(A=A,rank = K, maxiter=200)
runtime = time.time() - start
print("runtime: {}".format(runtime))
print("loss: {}".format(nmf.cost(A,W_nmf,H_nmf)))

nmf
runtime: 119.73408222198486
loss: 738505818.5604401


In [3]:
print("rnmf")
start = time.time()
(W_rnmf,H_rnmf) = nmf.compute_rnmf(A=A,rank = K, maxiter=200, oversample=20)
runtime = time.time() - start
print("runtime: {}".format(runtime))
print("loss: {}".format(nmf.cost(A,W_rnmf,H_rnmf)))

rnmf
runtime: 8.516527891159058
loss: 738516117.0531319


In [4]:
print("rnmf2")
start = time.time()
(W_rnmf2,H_rnmf2) = nmf.compute_rnmf2(A=A,rank = K, maxiter=200,oversample=20)
runtime = time.time() - start
print("runtime: {}".format(runtime))
print("loss: {}".format(nmf.cost(A,W_rnmf2,H_rnmf2)))

rnmf2
runtime: 11.374766826629639
loss: 738689006.4422762
