In [21]:
import numpy as np
import scipy.linalg as la
from numba import jit
from ipyparallel import Client
# Client handle for the ipyparallel cluster; ssvd_ipy uses `rc[:]` to reach the engines.
# NOTE(review): presumably requires engines that are already running — confirm before Run All.
rc = Client()

### Pure Python code

In [7]:
def thred(z,delta):
    """Soft-threshold `z` by `delta`, elementwise.

    Entries whose magnitude is below `delta` become zero; the others keep
    their sign and have their magnitude reduced by `delta`.
    """
    magnitude = np.abs(z)
    survives = magnitude >= delta          # mask of entries that are not zeroed
    return np.sign(z)*survives*(magnitude - delta)

def ssvd(X,gamu = 2, gamv =2, merr = 10**(-4), niter = 100):
    """Compute the first sparse SVD layer of X (pure NumPy baseline).

    Starting from the leading singular vectors of X, the right vector v and
    the left vector u are updated alternately.  Each update soft-thresholds
    the projection of X onto the other vector with adaptive weights
    |z|**gamma and keeps the threshold with the smallest BIC value.

    Parameters
    ----------
    X : ndarray, shape (n, d)
        Data matrix.
    gamu, gamv : float
        Exponents of the adaptive weights used for u and v.
    merr : float
        The loop stops once both u and v move by at most `merr`
        (Euclidean distance) between consecutive iterations.
    niter : int
        Iteration budget.  When `count` exceeds it a message is printed and
        the loop stops, so at most niter + 1 updates are performed.

    Returns
    -------
    u1 : ndarray, shape (n,)
    v1 : ndarray, shape (d,)
    s : ndarray, shape (1, 1)
        u1' X v1; the layer is s * u1 v1'.
    count : int
        Number of iterations performed.

    Raises
    ------
    ValueError
        If an update has no candidate threshold (e.g. X is all zeros).
    """
    n,d = X.shape
    # The original compared against `10^(-8)`, which is bitwise XOR (== -14)
    # in Python, so zero weights were never filtered and produced 0/0.
    min_weight = 1e-8
    log_nd = np.log(n*d)                 # BIC penalty per non-zero coefficient
    SST = np.sum(X*X)

    def soft(z, delta):
        # Local soft-threshold rule: keeps this baseline independent of
        # whichever module-level `thred` definition was executed last.
        return np.sign(z)*(np.abs(z)>=delta)*(np.abs(z)-delta)

    def sparse_update(z, gamma, dof, residual_ss):
        """Return the BIC-selected, unit-norm sparse vector for projection z."""
        w_inv = np.abs(z)**gamma                          # inverse adaptive weights
        sigsq = np.abs(SST - np.sum(z*z))/dof             # noise variance estimate
        cand = z*w_inv
        lambdas = np.unique(np.append(np.abs(cand), 0))   # sorted candidate thresholds
        keep = np.where(w_inv > min_weight)               # drop (near-)zero weights
        cand_k, w_k = cand[keep], w_inv[keep]

        def shrink(delta):
            full = np.zeros(len(z))
            full[keep] = soft(cand_k, delta)/w_k
            return full

        # The largest threshold is left out: it would zero every coefficient.
        bic = np.full(len(lambdas) - 1, np.inf)
        if bic.size == 0:
            raise ValueError("ssvd: no candidate threshold available (is X all zeros?)")
        for i in range(bic.size):
            full = shrink(lambdas[i])
            bic[i] = residual_ss(full)/sigsq + np.count_nonzero(full)*log_nd

        # argmin yields one scalar index even when several thresholds tie.
        best = shrink(lambdas[np.argmin(bic)])
        return best/np.sqrt(np.sum(best*best))

    # initial value of u and v: leading singular vectors
    U,_,VT = la.svd(X,full_matrices=False)
    u0 = U[:,0]
    v0 = VT.T[:,0]
    u1, v1 = u0, v0          # defined even if the loop body never runs

    ud = 1
    vd = 1
    count = 0

    while ud > merr or vd > merr:
        count += 1

        # Update v with u0 fixed, then u with the new v fixed.
        v1 = sparse_update(X.T @ u0, gamv, n*d - d,
                           lambda v: np.sum((X - np.outer(u0, v))**2))
        u1 = sparse_update(X @ v1, gamu, n*d - n,
                           lambda u: np.sum((X - np.outer(u, v1))**2))

        ud = np.sqrt(np.sum((u0-u1)*(u0-u1)))
        vd = np.sqrt(np.sum((v0-v1)*(v0-v1)))

        if count > niter :
            print("Fail to converge! Increase the niter!")
            break

        u0 = u1
        v0 = v1

    s = u1[None, :] @ X @ v1[:, None] #ssvd layer is suv.T
    return u1, v1, s, count

### Profiling

In [None]:
# Profile one full ssvd run.
# NOTE(review): X is only assigned in later cells (simulation / lung data); run one of those first.
%prun ssvd(X)

### Just In Time Compilation

In [13]:
@jit
def thred(z,delta):
    """Numba-compiled soft-threshold: zero entries with |z| < delta, shrink the rest by delta.

    Rebinds the module-level name `thred`, so every function that looks it
    up after this cell has run gets the compiled version.
    """
    magnitude = np.abs(z)
    survives = magnitude >= delta
    return np.sign(z)*survives*(magnitude - delta)

def ssvd_jit(X,gamu = 2, gamv =2, merr = 10**(-4), niter = 100):
    """First sparse SVD layer of X, using the module-level (jitted) `thred`.

    Same alternating algorithm as `ssvd`; the only compiled piece is the
    soft-threshold helper `thred`, which is looked up at call time.

    Parameters
    ----------
    X : ndarray, shape (n, d)
        Data matrix.
    gamu, gamv : float
        Exponents of the adaptive weights used for u and v.
    merr : float
        Stop once both u and v change by at most `merr` (Euclidean distance).
    niter : int
        Iteration budget; a message is printed and the loop stops once
        `count` exceeds it (at most niter + 1 updates).

    Returns
    -------
    (u1, v1, s, count)
        Unit-norm sparse vectors, the (1, 1) array u1' X v1, and the number
        of iterations performed.

    Raises
    ------
    ValueError
        If an update has no candidate threshold (e.g. X is all zeros).
    """
    n,d = X.shape
    # `10^(-8)` in the original is bitwise XOR (== -14), not 1e-8, so the
    # weight filter was a no-op and zero weights led to 0/0.
    min_weight = 1e-8
    log_nd = np.log(n*d)                 # BIC penalty per non-zero coefficient
    SST = np.sum(X*X)

    def sparse_update(z, gamma, dof, residual_ss):
        """Return the BIC-selected, unit-norm sparse vector for projection z."""
        w_inv = np.abs(z)**gamma                          # inverse adaptive weights
        sigsq = np.abs(SST - np.sum(z*z))/dof             # noise variance estimate
        cand = z*w_inv
        lambdas = np.unique(np.append(np.abs(cand), 0))   # sorted candidate thresholds
        keep = np.where(w_inv > min_weight)               # drop (near-)zero weights
        cand_k, w_k = cand[keep], w_inv[keep]

        def shrink(delta):
            full = np.zeros(len(z))
            full[keep] = thred(cand_k, delta)/w_k
            return full

        # The largest threshold is skipped: it would zero every coefficient.
        bic = np.full(len(lambdas) - 1, np.inf)
        if bic.size == 0:
            raise ValueError("ssvd_jit: no candidate threshold available (is X all zeros?)")
        for i in range(bic.size):
            full = shrink(lambdas[i])
            bic[i] = residual_ss(full)/sigsq + np.count_nonzero(full)*log_nd

        # argmin yields one scalar index even when several thresholds tie.
        best = shrink(lambdas[np.argmin(bic)])
        return best/np.sqrt(np.sum(best*best))

    # initial value of u and v: leading singular vectors
    U,_,VT = la.svd(X,full_matrices=False)
    u0 = U[:,0]
    v0 = VT.T[:,0]
    u1, v1 = u0, v0          # defined even if the loop body never runs

    ud = 1
    vd = 1
    count = 0

    while ud > merr or vd > merr:
        count += 1

        v1 = sparse_update(X.T @ u0, gamv, n*d - d,
                           lambda v: np.sum((X - np.outer(u0, v))**2))
        u1 = sparse_update(X @ v1, gamu, n*d - n,
                           lambda u: np.sum((X - np.outer(u, v1))**2))

        ud = np.sqrt(np.sum((u0-u1)*(u0-u1)))
        vd = np.sqrt(np.sum((v0-v1)*(v0-v1)))

        if count > niter :
            print("Fail to converge! Increase the niter!")
            break

        u0 = u1
        v0 = v1

    s = u1[None, :] @ X @ v1[:, None] #ssvd layer is suv.T
    return u1, v1, s, count

In [9]:
# Simulated rank-one benchmark: X = s * u v' + standard normal noise.
SIM_SEED = 0                              # fixed seed so results and timings are reproducible
rng = np.random.default_rng(SIM_SEED)

# Left vector (length 100): 10, 9, ..., 3, then seventeen 2s, then 75 zeros; unit norm.
u_ = np.arange(10,2,-1)
r1 = np.ones(17)*2
r2 = np.zeros(75)
u = np.concatenate([u_,r1,r2])
u = u/np.linalg.norm(u)

# Right vector (length 50): 16 non-zero entries followed by 34 zeros; unit norm.
v_ = np.array([10,-10,8,-8,5,-5,3,3,3,3,3,-3,-3,-3,-3,-3])
r3 = np.zeros(34)
v = np.concatenate([v_,r3])
v = v/np.linalg.norm(v)

s = 50                                    # singular value of the noiseless layer
X_star = s * np.outer(u, v)               # noiseless rank-one signal
X = X_star + rng.normal(0, 1, size=X_star.shape)   # noise shape follows the signal

#### Simulation dataset

In [15]:
# Baseline timing of ssvd on the current X.
# NOTE(review): ssvd looks up `thred` at call time, so if the @jit cell has already run
# this "pure Python" timing also goes through the jitted helper.
%timeit ssvd(X)

72.8 ms ± 4.51 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [16]:
# ssvd_jit has the same body as ssvd; only the `thred` helper is decorated with @jit.
# NOTE(review): the first call presumably includes numba compilation time — confirm with a warm-up call.
%timeit ssvd_jit(X)

65.1 ms ± 24.1 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


### Multiprocessing

In [None]:
def calculate_Bv(X,delta,cand1,winv1,sigsq,u0,ind):
    """BIC value of the v-update for a single threshold `delta`.

    `cand1` is soft-thresholded by `delta`, divided by `winv1` and written
    into positions `ind` of a length-d zero vector.  The result is the
    squared residual of X minus the outer product of `u0` with that vector,
    divided by `sigsq`, plus log(n*d) per non-zero coefficient.
    """
    n_rows, n_cols = X.shape
    magnitude = np.abs(cand1)
    shrunk = np.sign(cand1)*(magnitude>=delta)*(magnitude-delta)
    shrunk = shrunk/winv1

    v_full = np.zeros(n_cols)
    v_full[ind] = shrunk

    residual = X - u0[:,None] @ v_full[None,:]
    fit_term = np.sum(residual**2)/sigsq
    penalty = np.sum(shrunk!=0)*np.log(n_rows*n_cols)
    return fit_term + penalty


def calculate_Bu(X,delta,cand1,winu1,sigsq,v1,ind):
    """BIC value of the u-update for a single threshold `delta`.

    `cand1` is soft-thresholded by `delta`, divided by `winu1` and written
    into positions `ind` of a length-n zero vector.  The result is the
    squared residual of X minus the outer product of that vector with `v1`,
    divided by `sigsq`, plus log(n*d) per non-zero coefficient.
    """
    n_rows, n_cols = X.shape
    magnitude = np.abs(cand1)
    shrunk = np.sign(cand1)*(magnitude>=delta)*(magnitude-delta)
    shrunk = shrunk/winu1

    u_full = np.zeros(n_rows)
    u_full[ind] = shrunk

    residual = X - u_full[:,None] @ v1[None,:]
    fit_term = np.sum(residual**2)/sigsq
    penalty = np.sum(shrunk!=0)*np.log(n_rows*n_cols)
    return fit_term + penalty


def ssvd_para(X,gamu = 2, gamv =2, merr = 10**(-4), niter = 100, processes = 4):
    """First sparse SVD layer of X, with the BIC grid evaluated by a process pool.

    Same alternating algorithm as `ssvd`; for every candidate threshold the
    BIC value is computed by `calculate_Bv` / `calculate_Bu` in a
    `multiprocessing.Pool` worker.

    Parameters
    ----------
    X : ndarray, shape (n, d)
        Data matrix.
    gamu, gamv : float
        Exponents of the adaptive weights used for u and v.
    merr : float
        Stop once both u and v change by at most `merr` (Euclidean distance).
    niter : int
        Iteration budget; a message is printed and the loop stops once
        `count` exceeds it (at most niter + 1 updates).
    processes : int
        Number of worker processes (default 4, the value that used to be
        hard-coded).

    Returns
    -------
    (u1, v1, s, count)
        Unit-norm sparse vectors, the (1, 1) array u1' X v1, and the number
        of iterations performed.

    Raises
    ------
    ValueError
        If an update has no candidate threshold (e.g. X is all zeros).

    Notes
    -----
    Expects the seven-argument `calculate_Bv` / `calculate_Bu`.
    NOTE(review): the IPyParallel cell rebinds both names to one-argument
    versions; re-run the multiprocessing cell before calling this function.
    """
    # `Pool` was used without ever being imported anywhere in the notebook.
    from multiprocessing import Pool

    n,d = X.shape
    # `10^(-8)` in the original is bitwise XOR (== -14), not 1e-8.
    min_weight = 1e-8
    SST = np.sum(X*X)

    # initial value of u and v: leading singular vectors
    U,_,VT = la.svd(X,full_matrices=False)
    u0 = U[:,0]
    v0 = VT.T[:,0]
    u1, v1 = u0, v0          # defined even if the loop body never runs

    ud = 1
    vd = 1
    count = 0

    # One pool for the whole run instead of two new pools per iteration.
    with Pool(processes=processes) as pool:

        def sparse_update(z, gamma, dof, scorer, fixed):
            """Return the BIC-selected, unit-norm sparse vector for projection z."""
            w_inv = np.abs(z)**gamma                          # inverse adaptive weights
            sigsq = np.abs(SST - np.sum(z*z))/dof             # noise variance estimate
            cand = z*w_inv
            lambdas = np.unique(np.append(np.abs(cand), 0))   # sorted candidate thresholds
            keep = np.where(w_inv > min_weight)               # drop (near-)zero weights
            cand_k, w_k = cand[keep], w_inv[keep]

            # The largest threshold is skipped: it would zero every coefficient.
            jobs = [(X, delta, cand_k, w_k, sigsq, fixed, keep) for delta in lambdas[:-1]]
            if not jobs:
                raise ValueError("ssvd_para: no candidate threshold available (is X all zeros?)")
            bic = pool.starmap(scorer, jobs)

            # argmin yields one scalar index even when several thresholds tie.
            best_delta = lambdas[int(np.argmin(bic))]
            full = np.zeros(len(z))
            full[keep] = thred(cand_k, best_delta)/w_k
            return full/np.sqrt(np.sum(full*full))

        while ud > merr or vd > merr:
            count += 1

            v1 = sparse_update(X.T @ u0, gamv, n*d - d, calculate_Bv, u0)
            u1 = sparse_update(X @ v1, gamu, n*d - n, calculate_Bu, v1)

            ud = np.sqrt(np.sum((u0-u1)*(u0-u1)))
            vd = np.sqrt(np.sum((v0-v1)*(v0-v1)))

            if count > niter :
                print("Fail to converge! Increase the niter!")
                break

            u0 = u1
            v0 = v1

    s = u1[None, :] @ X @ v1[:, None] #ssvd layer is suv.T
    return u1, v1, s, count

### IPyParallel

In [22]:
def thred(z,delta):
    """Elementwise soft-threshold: zero where |z| < delta, otherwise shrink |z| by delta.

    Plain-NumPy definition; it rebinds the module-level name `thred`.
    """
    sign, size = np.sign(z), np.abs(z)
    return sign*(size >= delta)*(size - delta)

def calculate_Bv(delta):
    """BIC value of the v-update for threshold `delta`, reading its data from globals.

    X, cand1, winv1, sigsq, u0 and ind are taken from the global namespace of
    the process that runs this function (ssvd_ipy pushes them to the engines
    with `dv.push` before mapping this function over the thresholds).
    """
    import numpy as np  # imported here so the name exists where the function executes
    global X,cand1,winv1,sigsq,u0,ind

    n_rows, n_cols = X.shape
    magnitude = np.abs(cand1)
    shrunk = np.sign(cand1)*(magnitude>=delta)*(magnitude-delta)
    shrunk = shrunk/winv1

    v_full = np.zeros(n_cols)
    v_full[ind] = shrunk

    residual = X - u0[:,None] @ v_full[None,:]
    fit_term = np.sum(residual**2)/sigsq
    penalty = np.sum(shrunk!=0)*np.log(n_rows*n_cols)
    return fit_term + penalty


def calculate_Bu(delta):
    """BIC value of the u-update for threshold `delta`, reading its data from globals.

    X, cand1, winu1, sigsq, v1 and ind are taken from the global namespace of
    the process that runs this function.  Note that ssvd_ipy in this notebook
    evaluates the u-step locally and does not call this function.
    """
    import numpy as np  # imported here so the name exists where the function executes
    global X,cand1,winu1,sigsq,v1,ind

    n_rows, n_cols = X.shape
    magnitude = np.abs(cand1)
    shrunk = np.sign(cand1)*(magnitude>=delta)*(magnitude-delta)
    shrunk = shrunk/winu1

    u_full = np.zeros(n_rows)
    u_full[ind] = shrunk

    residual = X - u_full[:,None] @ v1[None,:]
    fit_term = np.sum(residual**2)/sigsq
    penalty = np.sum(shrunk!=0)*np.log(n_rows*n_cols)
    return fit_term + penalty


def ssvd_ipy(X,gamu = 2, gamv =2, merr = 10**(-4), niter = 100):
    """First sparse SVD layer of X, with the v-step BIC grid run on ipyparallel engines.

    Same alternating algorithm as `ssvd`.  For the v-update the per-threshold
    BIC values are computed by the one-argument `calculate_Bv` on the engines
    of the module-level client `rc`; the u-update is evaluated locally.

    Parameters
    ----------
    X : ndarray, shape (n, d)
        Data matrix.
    gamu, gamv : float
        Exponents of the adaptive weights used for u and v.
    merr : float
        Stop once both u and v change by at most `merr` (Euclidean distance).
    niter : int
        Iteration budget; a message is printed and the loop stops once
        `count` exceeds it (at most niter + 1 updates).

    Returns
    -------
    (u1, v1, s, count)
        Unit-norm sparse vectors, the (1, 1) array u1' X v1, and the number
        of iterations performed.

    Raises
    ------
    ValueError
        If an update has no candidate threshold (e.g. X is all zeros).
    """
    n,d = X.shape
    # `10^(-8)` in the original is bitwise XOR (== -14), not 1e-8, so the
    # weight filter was a no-op and zero weights led to 0/0.
    min_weight = 1e-8
    log_nd = np.log(n*d)                 # BIC penalty per non-zero coefficient
    SST = np.sum(X*X)

    # initial value of u and v: leading singular vectors
    U,_,VT = la.svd(X,full_matrices=False)
    u0 = U[:,0]
    v0 = VT.T[:,0]
    u1, v1 = u0, v0          # defined even if the loop body never runs

    ud = 1
    vd = 1
    count = 0

    # The view and X do not change between iterations: build/send them once.
    dv = rc[:]
    dv.push(dict(X=X), block=True)

    while ud > merr or vd > merr:
        count += 1

        # ---- Update v: BIC grid evaluated on the engines ----
        z = X.T @ u0
        winv = np.abs(z)**gamv # weight inverse
        sigsq = np.abs(SST-np.sum(z*z))/(n*d-d)
        cand = z*winv  #candidate lambda
        delt_uniq = np.unique(np.append(np.abs(cand),0))   # sorted thresholds
        if len(delt_uniq) < 2:
            raise ValueError("ssvd_ipy: no candidate threshold available (is X all zeros?)")

        ind = np.where(winv > min_weight)
        cand1 = cand[ind]
        winv1 = winv[ind]
        dv.push(dict(cand1=cand1,winv1=winv1,sigsq=sigsq,u0=u0,ind=ind), block=True)

        # Skip the largest threshold, as the serial versions do: it would
        # zero every coefficient and make the normalisation below 0/0.
        Bv = dv.map_sync(calculate_Bv, delt_uniq[:-1])

        # argmin yields one scalar index even when several thresholds tie.
        th = delt_uniq[int(np.argmin(Bv))]
        v1 = np.zeros(d)
        v1[ind] = thred(cand1, th)/winv1
        v1 = v1/(np.sqrt((np.sum(v1*v1)))) #v_new

        # ---- Update u: evaluated locally ----
        z = X @ v1
        winu = np.abs(z)**gamu
        sigsq = np.abs(SST - np.sum(z*z))/(n*d-n)
        cand = z*winu
        delt_uniq = np.unique(np.append(np.abs(cand),0))
        if len(delt_uniq) < 2:
            raise ValueError("ssvd_ipy: no candidate threshold available (is X all zeros?)")

        ind = np.where(winu > min_weight)
        cand1 = cand[ind]
        winu1 = winu[ind]
        Bu = np.full(len(delt_uniq)-1, np.inf)
        for i in range(len(Bu)):
            ushrink = np.zeros(n)
            ushrink[ind] = thred(cand1, delt_uniq[i])/winu1
            Bu[i] = np.sum((X - np.outer(ushrink, v1))**2)/sigsq + np.count_nonzero(ushrink)*log_nd

        th = delt_uniq[int(np.argmin(Bu))]
        u1 = np.zeros(n)
        u1[ind] = thred(cand1, th)/winu1
        u1 = u1/((np.sum(u1*u1))**0.5)

        ud = np.sqrt(np.sum((u0-u1)*(u0-u1)))
        vd = np.sqrt(np.sum((v0-v1)*(v0-v1)))

        if count > niter :
            print("Fail to converge! Increase the niter!")
            break

        u0 = u1
        v0 = v1

    s = u1[None, :] @ X @ v1[:, None] #ssvd layer is suv.T
    return u1, v1, s, count

#### Lung cancer dataset

In [17]:
# Read the lung data as a plain numeric text table (path is relative to the working directory).
lung = np.loadtxt("lung data.txt")
# Analyse the transpose. NOTE: this rebinds X, replacing the simulated matrix defined earlier.
# NOTE(review): presumably rows of X are samples and columns are genes after the transpose — confirm.
X = lung.T

In [24]:
%%time
# Serial run on the lung matrix; the returned (u, v, s, count) tuple is echoed as the cell output.
ssvd(X)

CPU times: user 29min 46s, sys: 1h 29min 44s, total: 1h 59min 30s
Wall time: 15min 33s


(array([-0.18018641, -0.2111319 , -0.10737516, -0.15574824, -0.15112323,
        -0.16025416, -0.19018526, -0.16695614, -0.19288482, -0.1101655 ,
        -0.13320652, -0.16567419, -0.18539703, -0.13521102, -0.17296825,
        -0.11586789, -0.19550012, -0.06831533, -0.17865369, -0.19547649,
         0.09094302,  0.04679269,  0.08135216,  0.0319133 ,  0.05270603,
         0.0326062 ,  0.08880033,  0.07646915,  0.07221645,  0.04541018,
         0.06393339,  0.0096247 , -0.01926316,  0.16077355,  0.15437588,
         0.15564448,  0.13808241,  0.14590629,  0.15895225,  0.16920658,
         0.17449384,  0.15142279,  0.1758557 ,  0.13815952,  0.14764779,
         0.14680345,  0.1735995 ,  0.14234025,  0.12499649,  0.1455915 ,
         0.07617165, -0.03848219, -0.04149155,  0.0123531 , -0.        ,
        -0.11083854]),
 array([0.       , 0.0126103, 0.       , ..., 0.       , 0.       ,
        0.       ]),
 array([[197.2569908]]),
 6)

In [23]:
%%time
# ipyparallel run on the same X; relies on the module-level client `rc` and the
# one-argument calculate_Bv defined in the IPyParallel cell.
ssvd_ipy(X)

CPU times: user 7.32 s, sys: 18 s, total: 25.3 s
Wall time: 9min 54s


(array([-0.18018641, -0.2111319 , -0.10737516, -0.15574824, -0.15112323,
        -0.16025416, -0.19018526, -0.16695614, -0.19288482, -0.1101655 ,
        -0.13320652, -0.16567419, -0.18539703, -0.13521102, -0.17296825,
        -0.11586789, -0.19550012, -0.06831533, -0.17865369, -0.19547649,
         0.09094302,  0.04679269,  0.08135216,  0.0319133 ,  0.05270603,
         0.0326062 ,  0.08880033,  0.07646915,  0.07221645,  0.04541018,
         0.06393339,  0.0096247 , -0.01926316,  0.16077355,  0.15437588,
         0.15564448,  0.13808241,  0.14590629,  0.15895225,  0.16920658,
         0.17449384,  0.15142279,  0.1758557 ,  0.13815952,  0.14764779,
         0.14680345,  0.1735995 ,  0.14234025,  0.12499649,  0.1455915 ,
         0.07617165, -0.03848219, -0.04149155,  0.0123531 , -0.        ,
        -0.11083854]),
 array([0.       , 0.0126103, 0.       , ..., 0.       , 0.       ,
        0.       ]),
 array([[197.2569908]]),
 6)