# Subdivide $D$

Notebook to optimize 

$$f(s_r,s_c)= \|H_a\|_* + \|H_b\|_*$$

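Here $s_c$ and $s_r$ are Boolean vectors that determine whether each column/row of $D$ belongs to set $a$ or to set $b$.
The matrices $H_a$ and $H_b$ are the submatrices of $D$ formed by the rows and columns in set $a$ and in set $b$, respectively, and $\|\cdot\|_*$ denotes the nuclear norm (the sum of the singular values).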


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import block_diag
from scipy.linalg import svdvals

In [None]:
def segment_matrix(A,normalize = True):
    """
    Split the columns of ``A`` into two groups using spectral clustering.

    The affinity between columns i and j is |(A^T A)_ij|. From it a graph
    Laplacian L = Deg - W is built (W: affinities with zero diagonal,
    Deg: diagonal matrix of the column sums of W). The sign pattern of the
    eigenvector of the second-smallest eigenvalue (the Fiedler vector)
    assigns each column to a group.

    Parameters
        A: 2-D array whose columns are segmented.
        normalize: if True, use the symmetric normalized Laplacian
            Deg^{-1/2} L Deg^{-1/2} instead of L.

    returns
        s: boolean vector with one entry per column of A, True where the
           Fiedler vector is positive. The overall sign of an eigenvector is
           arbitrary, so which of the two groups is labelled True is arbitrary.
    """
    # Off-diagonal entries are the negated affinities -|A^T A|.
    L = -np.abs(A.T@A)
    # Drop self-affinities, then put the degrees on the diagonal so that every
    # column of L sums to zero.
    np.fill_diagonal(L,0)
    np.fill_diagonal(L,-np.sum(L,axis=0))

    if normalize:
        # Symmetric normalization Deg^{-1/2} L Deg^{-1/2}.
        # NOTE(review): a column with zero affinity to every other column
        # (e.g. an all-zero column of A) has degree 0 and produces inf/nan here.
        d = np.sqrt(1/np.diag(L))
        L = d.reshape(1,-1)*L*d.reshape(-1,1)

    # L is real symmetric, so eigh is the appropriate solver: it returns real
    # eigenvalues in ascending order with orthonormal real eigenvectors,
    # whereas eig may return complex output that needs sorting and .real.
    _, v = np.linalg.eigh(L)

    # Column 1 belongs to the second-smallest eigenvalue (Fiedler vector).
    return v[:,1]>0

def get_initial(A):
    """Return an initial ``(column_split, row_split)`` pair for ``A``.

    The column split is ``segment_matrix(A)``; the row split is the column
    split of the transpose, ``segment_matrix(A.T)``.
    """
    return segment_matrix(A), segment_matrix(A.T)

In [None]:
def norm_add(U,s,Vt,c):
    """
    Nuclear norm of [A, c] from a thin SVD A = U @ diag(s) @ Vt.

    Writing c = U m + p with p orthogonal to the columns of U, the singular
    values of [A, c] are those of the small (k+1) x (k+1) matrix

        K = [[diag(s), m    ],
             [0,       ||p||]]

    so no SVD of the full matrix is needed. ``Vt`` is not used; it is part of
    the signature so that it mirrors ``norm_remove``. Assumes U has
    orthonormal columns.
    """
    k = len(s)
    proj = U.T@c          # coordinates of c in the column space of U
    resid = c-U@proj      # part of c orthogonal to that column space

    K = np.zeros((k+1,k+1), dtype=np.result_type(s, proj, np.float64))
    K[:k,:k] = np.diag(s)
    K[:k,k] = proj
    K[k,k] = np.linalg.norm(resid)
    return np.sum(svdvals(K,overwrite_a=True,check_finite=False))

def norm_remove(U,s,Vt,c,i_c):
    """
    Nuclear norm of A without column ``i_c``, from a thin SVD A = U diag(s) Vt.

    ``c`` is expected to be that column, A[:, i_c]. Removing it is written as
    the rank-one update A + c b^T with b = -e_{i_c}, which zeroes the column
    (a zero column does not change the singular values). With
    m = U^T c, p = c - U m, n = Vt b and q = b - Vt^T n, the singular values
    of the updated matrix are those of

        K = diag([s, 0]) + [m; ||p||] [n; ||q||]^T
    """
    e = np.zeros(Vt.shape[1])
    e[i_c] = -1                  # b = -e_{i_c}

    m = U.T@c
    n = Vt@e                     # equals -Vt[:, i_c]
    left = np.append(m, np.linalg.norm(c-U@m))
    right = np.append(n, np.linalg.norm(e-Vt.T@n))

    K = np.diag(np.append(s,0)) + np.outer(left,right)
    return np.linalg.norm(K,'nuc')

def get_f_change(Ua,sa,Vta,Ub,sb,Vtb,Ma,Mb,v):
    """
    Objective value after flipping each single entry of ``v``.

    With M = [Ma \\ Mb] (same columns, disjoint rows) and a boolean vector v,

        f(v) = ||Ma[:,v]||_* + ||Mb[:,not(v)]||_*

    Entry i of the result is f(v') where v' equals v with element i inverted.
    The thin SVDs of the current parts are supplied:

        Ua,sa,Vta = svd(Ma[:,v])
        Ub,sb,Vtb = svd(Mb[:,not(v)])

    Each flip is evaluated with one norm_add and one norm_remove call instead
    of a new SVD. The SVDs may also include extra columns that never change,
    as long as they come after the columns selected by v, so that the column
    positions counted here remain valid.
    """
    f_change = np.zeros(len(v))
    # Position of the current column inside Ma[:,v] / Mb[:,~v]; this is the
    # column index norm_remove needs in Vta / Vtb.
    pos_a = pos_b = 0
    for i, in_a in enumerate(v):
        col_a, col_b = Ma[:,i], Mb[:,i]
        if in_a:
            # column i leaves set a and joins set b
            f_change[i] = norm_add(Ub,sb,Vtb,col_b) + norm_remove(Ua,sa,Vta,col_a,pos_a)
            pos_a += 1
        else:
            # column i leaves set b and joins set a
            f_change[i] = norm_add(Ua,sa,Vta,col_a) + norm_remove(Ub,sb,Vtb,col_b,pos_b)
            pos_b += 1
    return f_change


In [None]:
def f(Ma,Mb,v):
    """
    Objective of a partition.

    With M = [Ma \\ Mb] and a boolean column selector v,

        f(v) = ||Ma[:,v]||_* + ||Mb[:,not(v)]||_*

    i.e. the nuclear norm of the columns of Ma selected by v plus the nuclear
    norm of the remaining columns of Mb.
    """
    parts = (Ma[:,v], Mb[:,~v])
    return sum(np.linalg.norm(part,'nuc') for part in parts)

In [None]:
# Sanity checks for the rank-one nuclear-norm updates. Every printed value is
# the difference between an updated and a directly computed nuclear norm, so
# it should be ~0 up to round-off.
N, M = 8, 7
A = np.random.rand(N,M)

# Column i is held out; ind_b selects the remaining columns.
i = 4
ind_b = np.arange(M) != i

# norm_add: start without column i, add it back, compare with ||A||_*.
U, s, Vt = np.linalg.svd(A[:,ind_b], full_matrices=False)
print(norm_add(U,s,Vt,A[:,i]) - np.linalg.norm(A,'nuc'))

# norm_remove: start from all of A, drop column i, compare with the direct value.
U, s, Vt = np.linalg.svd(A, full_matrices=False)
print(norm_remove(U,s,Vt,A[:,i],i) - np.linalg.norm(A[:,ind_b],'nuc'))

# get_f_change: rows 0-3 form Ma, rows 4-7 form Mb; columns 0-2 start in set b.
Ma, Mb = A[:4], A[4:]
v = np.arange(M) >= 3

Ua, sa, Vta = np.linalg.svd(Ma[:,v], full_matrices=False)
Ub, sb, Vtb = np.linalg.svd(Mb[:,~v], full_matrices=False)

f_ch = get_f_change(Ua,sa,Vta,Ub,sb,Vtb,Ma,Mb,v)
for i in range(v.size):
    v_prime = v.copy()
    v_prime[i] = ~v_prime[i]
    print(f(Ma,Mb,v_prime) - f_ch[i])

In [None]:
# Fresh random matrix for the single-step experiments below.
N, M = 8, 10
A = np.random.rand(N, M)


In [None]:
# Spectral initial guess for the column and row partitions, one per line.
s_col, s_row = get_initial(A)
print(s_col, s_row, sep="\n")

In [None]:
# Split the rows, evaluate the objective at the starting partition and
# compute the objective for every single-column flip.
Ma, Mb = A[s_row], A[~s_row]
f_start = f(Ma, Mb, s_col)
print("start f(v)=", f_start)

# Thin SVDs of H_a = Ma[:, s_col] and H_b = Mb[:, ~s_col].
Ua, sa, Vta = np.linalg.svd(Ma[:, s_col], full_matrices=False)
Ub, sb, Vtb = np.linalg.svd(Mb[:, ~s_col], full_matrices=False)
fs = get_f_change(Ua, sa, Vta, Ub, sb, Vtb, Ma, Mb, s_col)

# Change of the objective per column flip (negative = improvement).
fs - f_start

In [None]:
# Row flips via the transposed problem: H_a^T = Mat[:, s_row] and
# H_b^T = Mbt[:, ~s_row], whose thin SVDs are the transposes of the ones above.
Mat, Mbt = A.T[s_col], A.T[~s_col]
fst = get_f_change(Vta.T, sa, Ua.T, Vtb.T, sb, Ub.T, Mat, Mbt, s_row)
fst - f_start

In [None]:
# Flip columns: among the single flips that lower f, swap the best ones,
# taking the same number from set a and set b so that both sets keep their size.
ordering = np.argsort(fs)
n_neg = np.count_nonzero(fs < f_start)
n_neg_a = np.count_nonzero(s_col[ordering[:n_neg]])
n_neg_b = n_neg - n_neg_a
n_flip = min(n_neg_a, n_neg_b)
col_in_a = s_col[ordering]
flip_a = ordering[col_in_a][:n_flip]
flip_b = ordering[~col_in_a][:n_flip]
print("n_flip", n_flip)
print("flipa", flip_a)
print("flipb", flip_b)
print(fs[flip_a] - f_start)
print(fs[flip_b] - f_start)
s_col[flip_a] = False
s_col[flip_b] = True

In [None]:
# Flip rows the same way, using the transposed-problem values fst.
ordering = np.argsort(fst)
n_neg = np.count_nonzero(fst < f_start)
n_neg_a = np.count_nonzero(s_row[ordering[:n_neg]])
n_neg_b = n_neg - n_neg_a
n_flip = min(n_neg_a, n_neg_b)
row_in_a = s_row[ordering]
flip_a = ordering[row_in_a][:n_flip]
flip_b = ordering[~row_in_a][:n_flip]
print("n_flip", n_flip)
print("flipa", flip_a)
print("flipb", flip_b)
print(fst[flip_a] - f_start)
print(fst[flip_b] - f_start)
s_row[flip_a] = False
s_row[flip_b] = True

In [None]:
# Display the current column partition (True = column belongs to set a).
s_col

In [None]:
# Display the current row partition (True = row belongs to set a).
s_row

## Combine the steps into an iterative algorithm

In [None]:
# Larger random problem. As a baseline, split rows and columns at the middle
# index and evaluate the objective for that split.
A = np.random.rand(100,120)

half_r, half_c = A.shape[0]//2, A.shape[1]//2
f_base = (np.linalg.norm(A[:half_r,:half_c],'nuc')
          + np.linalg.norm(A[half_r:,half_c:],'nuc'))
print(f_base)

In [None]:
N = 50  # maximum number of refinement iterations


def _balanced_flip(changes, s, f_ref):
    """Flip the most improving entries of ``s`` in place, keeping set sizes fixed.

    changes[i] is the objective after flipping entry i alone; entries with
    changes[i] < f_ref are improving. The best n_flip entries currently in
    set a (s True) are moved to b and the best n_flip entries in set b are
    moved to a, where n_flip is the smaller of the two improving counts.

    Returns n_flip (0 means ``s`` was not changed).
    """
    n_neg = np.count_nonzero(changes<f_ref)
    ordering = np.argsort(changes)
    n_neg_a = np.count_nonzero(s[ordering[:n_neg]])
    n_flip = min(n_neg_a, n_neg-n_neg_a)
    flip_a = ordering[s[ordering]][:n_flip]
    flip_b = ordering[~s[ordering]][:n_flip]
    s[flip_a] = False
    s[flip_b] = True
    return n_flip


# History of the objective and of both partitions; only the first n+2 entries
# are filled if the loop stops early.
fs_list = np.zeros(N+1)
s_cols = np.zeros((N+1,A.shape[1]))
s_rows = np.zeros((N+1,A.shape[0]))

s_col,s_row = get_initial(A)
Ma = A[s_row]
Mb = A[~s_row]

fs_list[0] = f(Ma,Mb,s_col)
s_cols[0] = s_col
s_rows[0] = s_row

for n in range(N):
    Ma = A[s_row]
    Mb = A[~s_row]

    f_ref = f(Ma,Mb,s_col)

    # Objective for every single-column flip (rows fixed).
    Ua,sa,Vta = np.linalg.svd(Ma[:,s_col],full_matrices=False)
    Ub,sb,Vtb = np.linalg.svd(Mb[:,~s_col],full_matrices=False)
    fs = get_f_change(Ua,sa,Vta,Ub,sb,Vtb,Ma,Mb,s_col)

    # Objective for every single-row flip (columns fixed), via the transposed
    # problem whose SVDs are the transposes of the ones above.
    Mat = A.T[s_col]
    Mbt = A.T[~s_col]
    fst = get_f_change(Vta.T,sa,Ua.T,Vtb.T,sb,Ub.T,Mat,Mbt,s_row)

    # Row and column flips are chosen independently from the same state, so
    # the combined step is not guaranteed to lower f.
    n_flip_c = _balanced_flip(fs, s_col, f_ref)
    n_flip_r = _balanced_flip(fst, s_row, f_ref)

    # Rows were just flipped: rebuild Ma/Mb from the updated s_row so the
    # recorded objective belongs to the current (s_row, s_col) partition.
    Ma = A[s_row]
    Mb = A[~s_row]
    fs_list[n+1] = f(Ma,Mb,s_col)
    s_cols[n+1] = s_col
    s_rows[n+1] = s_row

    if n_flip_r==0 and n_flip_c==0:
        break
print(n)

In [None]:
# Objective per iteration against the half/half baseline f_base. Only the
# first n+2 entries of fs_list were filled; the rest are unused zeros.
plt.plot(fs_list[:n+2])
plt.hlines(f_base,0,n+1)

In [None]:
# Column partitions of all recorded iterations (rows 0..n+1 of s_cols).
plt.spy(s_cols[:n+2])

In [None]:
# Row partitions of all recorded iterations (rows 0..n+1 of s_rows).
plt.spy(s_rows[:n+2])

In [None]:
# Number of columns currently in set a.
np.count_nonzero(s_col)

In [None]:
# Number of rows currently in set a.
np.count_nonzero(s_row)