In [87]:
from scipy.linalg import orth
import numpy as np
import matplotlib.pyplot as plt

# Fast SVD approximation using the Moore-Penrose pseudoinverse !


In [244]:
# A is an m x n test matrix (m rows, n columns) with uniform random entries.
m, n = 800, 500
k = 10  # maximum number of singular values we use for estimation
A = np.random.rand(m, n)
# Thin SVD. NumPy returns the third factor already conjugate-transposed,
# i.e. V here holds V^H, not V.
U, S, V = np.linalg.svd(A, full_matrices=False)

In [245]:
# Show the shapes of the three thin-SVD factors, one per line.
print(U.shape, S.shape, V.shape, sep="\n")

(800, 500)
(500,)
(500, 500)


In [246]:
# Make the spectrum decay: every singular value from index k+1 onwards is
# divided by its own index, so the tail of the spectrum becomes small.
S[k+1:] *= np.arange(k+1, min(n, m), dtype=float) ** -1

# NOTE: S is rebound from the 1-D vector of singular values to the square
# diagonal matrix. Later cells index S as a 2-D array, so re-running this
# cell without re-running the SVD cell first will not give the same result.
S = np.diag(S)
print(S.shape)

# np.linalg.svd already returned the third factor as V^H, so the matrix is
# rebuilt as U @ S @ V with no additional conjugate transpose. Transposing
# again would pair the singular values with the wrong right singular vectors
# and only multiplies at all when V happens to be square.
A = U @ S @ V
print(A.shape)

(500, 500)
(800, 500)


In [247]:
# Debug check: print the shape of S[k+1], followed by one empty line.
print(S[k+1].shape, end="\n\n")

(500,)



In [248]:
def random_sketch_approx(A,k_over):
    """Approximate the dominant SVD of ``A`` from a small randomized sketch.

    The columns of ``A`` are multiplied by random unit-modulus phases, mixed
    with a unitary FFT across the column dimension, and ``k_over`` of the
    mixed columns are kept. An orthonormal basis ``q`` of that sketch is used
    to project ``A`` onto a small matrix ``q^H A`` whose SVD is cheap.

    Parameters
    ----------
    A : array_like, shape (m, n)
        Matrix to approximate. Real or complex, tall, square or wide.
    k_over : int
        Number of sketch columns to keep (target rank plus oversampling).
        Values larger than ``n`` are clamped to ``n``.

    Returns
    -------
    [uu, ss, vv] : list of ndarray
        Thin SVD factors of the projected matrix ``q^H A`` exactly as returned
        by ``np.linalg.svd`` (``vv`` is already conjugate-transposed).
        ``ss`` approximates the leading singular values of ``A`` and
        ``q @ uu`` approximates the corresponding left singular vectors.
    q : ndarray, shape (m, r) with r <= k_over
        Orthonormal basis of the sketch; ``q @ q^H @ A`` approximates ``A``.

    Raises
    ------
    ValueError
        If ``A`` is not a non-empty 2-D array or ``k_over`` is smaller than 1.
    """
    A = np.asarray(A)
    if A.ndim != 2 or A.size == 0:
        raise ValueError("A must be a non-empty 2-D array")

    n_cols = A.shape[1]
    k_over = int(k_over)
    if k_over < 1:
        raise ValueError("k_over must be a positive integer")
    # Sampling is done without replacement, so at most n_cols columns exist.
    k_over = min(k_over, n_cols)

    # One random unit-modulus phase per column of A.
    phases = np.exp(1j*2*np.pi*np.random.rand(n_cols,1))
    # Distinct indices: duplicated picks would add nothing to the sketch and
    # would lower its rank.
    picked = np.random.choice(n_cols, size=k_over, replace=False)

    # Work on A^T so the column dimension of A is axis 0: scale, mix with a
    # unitary FFT along that axis, then keep the selected mixed columns.
    a_sketch = A.T * phases
    a_sketch = np.fft.fft(a_sketch, n_cols, axis=0) / np.sqrt(n_cols)
    a_sketch = a_sketch[picked,:].T

    # Orthonormal basis of the sketch and the projected small problem.
    q = orth(a_sketch)
    A_small = q.conjugate().T @ A
    uu, ss, vv = np.linalg.svd(A_small, full_matrices=False)

    return [uu,ss,vv], q


# Test

In [249]:
[uu,ss,vv], q = random_sketch_approx(A,15)

# Despite its name, A_error holds the approximation of A: the projection of A
# onto the column space of q. The product is grouped as q @ (q^H @ A) so that
# the large square projector q @ q^H is never formed explicitly.
A_error = np.real(q @ (q.conjugate().T @ A))
norm = np.linalg.norm(A - A_error)
print("Absolute Error (norm): ",norm)

# np.diag pulls the singular values back out of S, which is assumed to be the
# 2-D diagonal matrix at this point; they are compared with the leading k
# singular values obtained from the sketch.
sigs_exact = np.diag(S)
print("Relative error:",(sigs_exact[:k]-ss[:k])/sigs_exact[:k])

Absolute Error (norm):  4.117761511988246
Relative error: [8.66681750e-06 1.74732326e-03 2.87798434e-03 1.65844722e-03
 2.26399367e-03 2.07205488e-03 2.51777592e-03 2.49454076e-03
 3.76928998e-03 4.05329820e-03]


# Time Trials 
## Numpy svd (np.linalg.svd(A))

In [250]:
%%timeit
U,S,V = np.linalg.svd(A,full_matrices=False)

50.3 ms ± 1.94 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


## Randomized sketch SVD (random_sketch_approx(A, 10))

In [251]:
%%timeit
[uu,ss,vv], q = random_sketch_approx(A,10)




17 ms ± 2.9 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
