# Clustering columns and rows

The idea is to cluster the columns and rows

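Here we have the objective function (as implemented by `f` below): $f(s_{col}, s_{row}) = \| X[s_{row}, s_{col}] \|_* + \| X[\neg s_{row}, \neg s_{col}] \|_*$, where $\|\cdot\|_*$ is the nuclear norm.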



In [None]:
import numpy as np
import tvsclib.utils as utils
import Split
import matplotlib.pyplot as plt
from tvsclib.strict_system import StrictSystem

import torchvision.models as models
import torch
import scipy.stats 

import scipy.linalg as linalg

In [None]:
def get_mobilenet_target_mats():
    """
    Collect the weight matrices of the linear layers in the MobileNetV2 classifier.

    returns:  list of numpy arrays, one per torch.nn.Linear layer found in
              model.classifier, in the order in which the layers appear
    """
    # Load the pretrained model and put it into eval mode
    model = models.mobilenet_v2(pretrained=True)
    model.eval()
    # Keep only the weights of the linear layers, detached from the graph
    return [
        layer.weight.detach().numpy()
        for layer in model.classifier
        if isinstance(layer, torch.nn.Linear)
    ]

In [None]:
# Random 32x32 test matrix
T = np.random.rand(32,32)
# Alternative input: first linear weight matrix of the MobileNetV2 classifier
#T = mats = get_mobilenet_target_mats()[0]
# Build the initial system (with a causal and an anticausal part) for T
# NOTE: the name `sys` shadows the standard library module of the same name
sys = Split.initial_mixed(T)
# Plot the system and check its dimensions
utils.show_system(sys,mark_D=False)
utils.check_dims(sys)

# Compute gramians/weights

In [None]:
# get the matrices of stage k of the causal and the anticausal system
k = 0
stage_c=sys.causal_system.stages[k]
stage_a=sys.anticausal_system.stages[k]
Ac = stage_c.A_matrix
Bc = stage_c.B_matrix
Cc = stage_c.C_matrix

Aa = stage_a.A_matrix
Ba = stage_a.B_matrix
Ca = stage_a.C_matrix

# D is taken from the causal stage only
D = stage_c.D_matrix

# dims of states

(d_out_c,d_in_c) = Ac.shape

(d_out_a,d_in_a) = Aa.shape

# stack all stage matrices into one matrix; the two corner blocks are zero
X = np.block([[np.zeros((d_out_a,d_in_c)),Ba,Aa ],
              [Cc,D,Ca],
              [Ac,Bc,np.zeros((d_out_c,d_in_a))]
    
])

plt.matshow(X)

# initial column mask: the first d_in_c columns plus the first half of the columns of D
s_c = np.zeros(X.shape[1],dtype=bool)
s_c[:d_in_c+D.shape[1]//2]=1
# initial row mask: the second half of the rows of D plus all rows below
s_r = np.zeros(X.shape[0],dtype=bool)
s_r[d_out_a+D.shape[0]//2:]=1

In [None]:
# split X into the columns selected by s_c (left) and the remaining ones (right)
X_l = X[:,s_c]
X_r = X[:,~s_c]

# scale every row of the two column groups to unit Euclidean norm
# NOTE(review): a row that is entirely zero gives a division by zero here
X_l = X_l/np.linalg.norm(X_l,axis=1).reshape(-1,1)
X_r = X_r/np.linalg.norm(X_r,axis=1).reshape(-1,1)

# split X into the rows not selected by s_r (top) and the selected ones (bottom)
X_t = X[~s_r]
X_b = X[s_r]

# scale every column of the two row groups to unit Euclidean norm
X_t = X_t/np.linalg.norm(X_t,axis=0).reshape(1,-1)
X_b = X_b/np.linalg.norm(X_b,axis=0).reshape(1,-1)

In [None]:
X_ = X_b #top or bottom
# Gram matrix of the normalized columns; its diagonal holds the squared column norms
print(np.diag(X_.T@X_))
plt.matshow(X_.T@X_)

In [None]:
X_ = X_r  #left or right
# Gram matrix of the normalized rows; print the absolute values of its diagonal
print(np.diag(abs(X_@X_.T)))
plt.matshow(X_@X_.T)

In [None]:
# dissimilarity between the rows of X_: one minus the absolute inner product of the normalized rows
plt.matshow(1-abs(X_@X_.T))

In [None]:
# column dissimilarity computed on the bottom rows
Adj = 1-abs(X_b.T@X_b)
# for every column: summed dissimilarity to the columns selected by s_c
np.sum(Adj[:,s_c],axis=1)

In [None]:
# column dissimilarity computed on the top rows
Adj = 1-abs(X_t.T@X_t)
# for every column: summed dissimilarity to the columns not selected by s_c
np.sum(Adj[:,~s_c],axis=1)

In [None]:
# row dissimilarity computed on the left columns
Adj = 1-abs(X_l@X_l.T)
# for every row: summed dissimilarity to the rows selected by s_r
np.sum(Adj[:,s_r],axis=1)

In [None]:
# row dissimilarity computed on the right columns
Adj = 1-abs(X_r@X_r.T)
# for every row: summed dissimilarity to the rows not selected by s_r
np.sum(Adj[:,~s_r],axis=1)

## Some helper functions

In [None]:
def f(X,s_col,s_row):
    """
    Objective function: sum of the nuclear norms of the two selected blocks of X.

    X:        matrix that is segmented
    s_col:    boolean mask over the columns of X
    s_row:    boolean mask over the rows of X

    returns:  nuclear norm of X[s_row][:,s_col] plus nuclear norm of X[~s_row][:,~s_col]
    """
    block_selected = X[s_row][:,s_col]
    block_complement = X[~s_row][:,~s_col]
    nuc_selected = np.linalg.norm(block_selected,'nuc')
    nuc_complement = np.linalg.norm(block_complement,'nuc')
    return nuc_selected+nuc_complement

def compute_sigmasf(X,s_col,s_row):
    """
    Singular values of the two blocks of X that enter the objective function.

    X:        matrix that is segmented
    s_col:    boolean mask over the columns of X
    s_row:    boolean mask over the rows of X

    returns:  tuple with the singular values of X[s_row][:,s_col]
              and the singular values of X[~s_row][:,~s_col]
    """
    sigmas_selected = np.linalg.svd(X[s_row][:,s_col],compute_uv=False)
    sigmas_complement = np.linalg.svd(X[~s_row][:,~s_col],compute_uv=False)
    return sigmas_selected, sigmas_complement

def show_matrices(X,s_col,s_row):
    """
    Plot the two blocks of X that enter the objective function.

    X[s_row][:,s_col] is drawn into figure 1 and X[~s_row][:,~s_col] into figure 2.

    X:        matrix that is segmented
    s_col:    boolean mask over the columns of X
    s_row:    boolean mask over the rows of X
    """
    blocks = (X[s_row][:,s_col], X[~s_row][:,~s_col])
    for fig_id,block in enumerate(blocks,start=1):
        plt.matshow(block,fignum=fig_id)

gamma = 1e2 # maybe have two different regularizations for rows and columns


def f_reg_row(x):
    """Regularization for the rows: -gamma*x**3, gamma is looked up when the function is called."""
    return -gamma*x**3


def f_reg_col(x):
    """Regularization for the columns: -gamma*x**3, gamma is looked up when the function is called."""
    return -gamma*x**3

## Iteration

In [None]:
def segment_matrix(stage_causal,stage_anticausal,N=50):
    """
    Iteratively segment the columns and rows of one stage into two clusters.

    The matrices of the causal and the anticausal stage are stacked into

        X = [[0,  Ba, Aa],
             [Cc, D,  Ca],
             [Ac, Bc, 0 ]]

    A boolean column mask s_c and a boolean row mask s_r are then updated
    for N iterations. Only the entries that belong to the columns/rows of D
    are reassigned, the state columns/rows keep their initial assignment.

    Uses the module-level helpers f, f_reg_col and f_reg_row.

    stage_causal:      causal stage, provides A_matrix, B_matrix, C_matrix and D_matrix
    stage_anticausal:  anticausal stage, provides A_matrix, B_matrix and C_matrix
    N:                 number of iterations

    returns: tuple (fs,s_cols,s_rows,X,s_c,s_r)
        fs:      objective f(X,s_c,s_r) before the first and after every iteration (length N+1)
        s_cols:  column masks, one row per entry of fs
        s_rows:  row masks, one row per entry of fs
        X:       the stacked matrix
        s_c:     final column mask restricted to the columns of D
        s_r:     final row mask restricted to the rows of D
    """

    Ac = stage_causal.A_matrix
    Bc = stage_causal.B_matrix
    Cc = stage_causal.C_matrix

    Aa = stage_anticausal.A_matrix
    Ba = stage_anticausal.B_matrix
    Ca = stage_anticausal.C_matrix

    D = stage_causal.D_matrix

    #regularization vectors: thresholds the sorted costs are compared against
    v_reg_col = f_reg_col(np.linspace(-1,1,D.shape[1]))
    v_reg_row = f_reg_row(np.linspace(-1,1,D.shape[0]))

    #dims of states
    (d_out_c,d_in_c) = Ac.shape
    (d_out_a,d_in_a) = Aa.shape

    #setup matrix
    X = np.block([[np.zeros((d_out_a,d_in_c)),Ba,Aa ],
                  [Cc,D,Ca],
                  [Ac,Bc,np.zeros((d_out_c,d_in_a))]
    ])

    #columns and rows of X that belong to D, only these can change the cluster
    cols_D = slice(d_in_c,X.shape[1]-d_in_a)
    rows_D = slice(d_out_a,X.shape[0]-d_out_c)

    #initialize segmentation
    s_c = np.zeros(X.shape[1],dtype=bool)
    s_c[:d_in_c+D.shape[1]//2]=1
    s_r = np.zeros(X.shape[0],dtype=bool)
    s_r[d_out_a+D.shape[0]//2:]=1

    #history of the objective and of the masks
    fs = np.zeros(N+1)
    s_cols=np.zeros((N+1,X.shape[1]),dtype=bool)
    s_rows=np.zeros((N+1,X.shape[0]),dtype=bool)

    s_cols[0]=s_c
    s_rows[0]=s_r
    fs[0]=f(X,s_c,s_r)

    for n in range(N):
        #normalize the rows of the left/right part
        X_l = X[:,s_c]
        X_r = X[:,~s_c]

        X_l = X_l/np.linalg.norm(X_l,axis=1).reshape(-1,1)
        X_r = X_r/np.linalg.norm(X_r,axis=1).reshape(-1,1)

        #normalize the columns of the top/bottom part
        X_t = X[~s_r]
        X_b = X[s_r]

        X_t = X_t/np.linalg.norm(X_t,axis=0).reshape(1,-1)
        X_b = X_b/np.linalg.norm(X_b,axis=0).reshape(1,-1)

        #compute cost matrices and cost vectors
        #columns: mean dissimilarity to the s_c cluster minus mean dissimilarity to the other cluster
        Adj_b = 1-abs(X_b.T@X_b)
        Adj_t = 1-abs(X_t.T@X_t)
        f_col = np.sum(Adj_b[:,s_c],axis=1)/np.count_nonzero(s_c) - np.sum(Adj_t[:,~s_c],axis=1)/np.count_nonzero(~s_c)

        #rows: same construction with the s_r cluster
        Adj_l = 1-abs(X_l@X_l.T)
        Adj_r = 1-abs(X_r@X_r.T)
        f_row = np.sum(Adj_l[:,s_r],axis=1)/np.count_nonzero(s_r) - np.sum(Adj_r[:,~s_r],axis=1)/np.count_nonzero(~s_r)

        #try to minimize -> choose smaller one
        #sort the columns/rows of D by their cost (ascending)
        ord_c = d_in_c+np.argsort(f_col[cols_D])
        ord_r = d_out_a+np.argsort(f_row[rows_D])
        s_c[cols_D]=0
        s_r[rows_D]=0
        #the sorted costs are compared entry by entry with the regularization
        #vector, entries below the threshold are put into the cluster
        v_c =ord_c[f_col[ord_c]<v_reg_col]
        #the rows have to be selected with the row cost f_row (not f_col)
        v_r =ord_r[f_row[ord_r]<v_reg_row]
        s_c[v_c]=1
        s_r[v_r]=1

        fs[n+1] = f(X,s_c,s_r)
        s_cols[n+1]=s_c
        s_rows[n+1]=s_r

    return fs,s_cols,s_rows,X,s_c[cols_D],s_r[rows_D]
        


In [None]:
# rebuild the initial system and segment stage k (k is set in an earlier cell)
sys = Split.initial_mixed(T)
stage_c=sys.causal_system.stages[k]
stage_a=sys.anticausal_system.stages[k]
# NOTE: s_c and s_r are overwritten with the masks restricted to the columns/rows of D
fs,s_cols,s_rows,X,s_c,s_r = segment_matrix(stage_c,stage_a,N=20)

In [None]:
# singular values of the two blocks for the initial and for the final segmentation
s_start_c,s_start_a = compute_sigmasf(X,s_cols[0],s_rows[0])
s_end_c,s_end_a = compute_sigmasf(X,s_cols[-1],s_rows[-1])

# left plot: block X[s_row][:,s_col]
plt.subplot(1,2,1)
plt.plot(s_start_c,label="start")
plt.plot(s_end_c,label="end")
plt.legend()
# right plot: block X[~s_row][:,~s_col]
plt.subplot(1,2,2)
plt.plot(s_start_a,label="start")
plt.plot(s_end_a,label="end")

In [None]:
# history of the column mask, one row per iteration
plt.spy(s_cols)
# fraction of columns that are selected after the last iteration
print(np.count_nonzero(s_cols[-1])/len(s_cols[-1]))

In [None]:
# history of the row mask, one row per iteration
plt.spy(s_rows)
# fraction of rows that are selected after the last iteration
# (this has to use s_rows, the column fraction is printed in the cell above)
print(np.count_nonzero(s_rows[-1])/len(s_rows[-1]))

In [None]:
# objective function over the iterations
plt.plot(fs)

## Some helper functions to apply a permutation and store it

In [None]:
def get_permutations(s_c,s_r):
    """
    Convert the boolean segmentation masks into permutations.

    s_c:    boolean mask over the columns
    s_r:    boolean mask over the rows

    returns: tuple (p_col,p_row,i_in,i_out)
        p_col:  column indices, the ones selected by s_c first, then the remaining ones
        p_row:  row indices, the ones not selected by s_r first, then the selected ones
        i_in:   number of selected entries in s_c
        i_out:  number of selected entries in s_r
    """
    cols_selected = np.flatnonzero(s_c)
    cols_remaining = np.flatnonzero(~s_c)
    rows_selected = np.flatnonzero(s_r)
    rows_remaining = np.flatnonzero(~s_r)

    p_col = np.concatenate((cols_selected,cols_remaining))
    p_row = np.concatenate((rows_remaining,rows_selected))

    # NOTE(review): p_row puts the rows with ~s_r first, but i_out counts the
    # rows with s_r -- confirm that this is the index expected by
    # Split.split_sigmas_mixed
    return p_col,p_row,np.count_nonzero(s_c),np.count_nonzero(s_r)

def permute_stage(stage,p_col,p_row):
    """
    Apply a column and a row permutation to a stage in place.

    The columns of B_tilde and D_matrix are reordered with p_col, the rows
    of C_tilde and D_matrix are reordered with p_row.

    stage:    stage with the attributes B_tilde, C_tilde and D_matrix
    p_col:    column permutation
    p_row:    row permutation
    """
    stage.B_tilde = stage.B_tilde[:,p_col]
    stage.C_tilde = stage.C_tilde[p_row,:]
    # D: first reorder the columns, then the rows
    D_cols_permuted = stage.D_matrix[:,p_col]
    stage.D_matrix = D_cols_permuted[p_row]
    
def collect_permutations(P_col,P_row,k,p_col,p_row,system):
    """
    Function to collect the permutations in P_col and P_row

    The part of P_col/P_row that belongs to stage k is reordered with
    p_col/p_row. Both arrays are modified in place.

    P_col:    total permutation of columns
    P_row:    total permutation of rows
    k:        index of stage
    p_col:    new column permutation
    p_row:    new row permutation
    system:   system that provides dims_in and dims_out
    """
    updates = ((P_col,p_col,system.dims_in),(P_row,p_row,system.dims_out))
    for P_total,p_new,dims in updates:
        # offset of stage k = sum of the dims of all previous stages
        start = int(np.sum(dims[:k]))
        window = slice(start,start+dims[k])
        P_total[window] = P_total[window][p_new]

In [None]:
# Small hand-made example to check the permutation helpers:
# every second column/row is selected, so the permutation has to sort them
N_col=5
N_row=4
s_c = np.array([0,1]*N_col,dtype=bool)
s_r = np.array([0,1]*N_row,dtype=bool)
A = np.zeros((1,1))
B = np.zeros((1,2*N_col))
C = np.zeros((2*N_row,1))
D = np.zeros((2*N_row,2*N_col))

# selected columns get the values 1..N_col, the other columns N_col+1..2*N_col
B[:,s_c]=np.arange(1,N_col+1)
B[:,~s_c]=np.arange(N_col+1,2*N_col+1)

# same for the rows; the second assignment has to fill the rows that are
# not selected (~s_r), otherwise these rows stay zero
C[s_r]=np.arange(1,N_row+1).reshape(-1,1)
C[~s_r]=np.arange(N_row+1,2*N_row+1).reshape(-1,1)

# fill the blocks D[s_r][:,s_c] and D[~s_r][:,~s_c] with outer products;
# the flat indices are needed because chained boolean indexing returns a copy
index = np.arange(4*N_col*N_row,dtype=int).reshape(2*N_row,2*N_col)
D.reshape(-1)[index[s_r][:,s_c].reshape(-1)]=(np.arange(1,N_col+1).reshape(1,-1)*np.arange(1,N_row+1).reshape(-1,1)).reshape(-1)
D.reshape(-1)[index[~s_r][:,~s_c].reshape(-1)]=(np.arange(N_col+1,2*N_col+1).reshape(1,-1)*np.arange(N_row+1,2*N_row+1).reshape(-1,1)).reshape(-1)
#D[s_r][:,s_c]#D[~s_r][:,~s_c]

# permute the stage and show the resulting matrices
stage = Split.Stage_sigmas(A,B,C,D,np.ones(1),np.ones(1))
p_col,p_row,i_in,i_out = get_permutations(s_c,s_r)
permute_stage(stage,p_col,p_row)
display(stage.A_matrix)
display(stage.B_matrix)
display(stage.C_matrix)
display(stage.D_matrix)


# embed the stage between two small stages to get a complete system
system = StrictSystem(stages=[
    Split.Stage_sigmas(np.zeros((1,0)),np.zeros((1,1)),np.zeros((2,0)),np.zeros((2,1)),np.zeros(1),np.zeros(0)),
    stage,
    Split.Stage_sigmas(np.zeros((0,1)),np.zeros((0,1)),np.zeros((2,1)),np.zeros((2,1)),np.zeros(1),np.zeros(0))
],causal=True)
utils.show_system(system,mark_D=False)

# collect the permutation of stage 1 in the total permutations of T
T=linalg.block_diag(np.zeros((2,1)),D,np.zeros((2,1)))
P_col = np.arange(T.shape[1],dtype=int)
P_row = np.arange(T.shape[0],dtype=int)
collect_permutations(P_col,P_row,1,p_col,p_row,system)
display(P_col)
display(P_row)

In [None]:
# T with the collected row and column permutations applied
plt.matshow(T[P_row][:,P_col])

## Combine it into an algorithm

In [None]:
# input and output dimensions of the blocks of the test matrix
dims_in =  np.array([6, 3, 5, 2])*3
dims_out = np.array([2, 5, 3, 6])*3

# (earlier variant based on the SVD of random matrices, kept for reference)
#create orthogonal vectors and normalize them to the size of the matrix (i.e. norm(block)/size(block) = const)
#Us =np.vstack([np.linalg.svd(np.random.rand(dims_out[i],dims_in[i]))[0][:,1:4]*dims_out[i] for i in range(len(dims_in))])
#Vts=np.hstack([np.linalg.svd(np.random.rand(dims_out[i],dims_in[i]))[2][1:4,:]*dims_in[i] for i in range(len(dims_in))])

#create orthogonal vectors and normalize them to the size of the matrix (i.e. norm(block)/size(block) = const)
# per block: three columns (rows) of a random orthogonal matrix, scaled with the block dimension
Us =np.vstack([scipy.stats.ortho_group.rvs(dims_out[i])[:,:3]*dims_out[i] for i in range(len(dims_in))])
Vts=np.hstack([scipy.stats.ortho_group.rvs(dims_in[i])[:3,:]*dims_in[i] for i in range(len(dims_in))])



# three rank-1 matrices: one for the part left of, on and right of the block diagonal
lower = Us[:,:1]@Vts[:1,:]
diag = Us[:,1:2]@Vts[1:2,:]
upper = Us[:,2:3]@Vts[2:3,:]
matrix = np.zeros_like(diag)
# a: first row of the current block, b: first column of the current block
a=0;b=0
for i in range(len(dims_in)):
    matrix[a:a+dims_out[i],:b]            =lower[a:a+dims_out[i],:b]
    matrix[a:a+dims_out[i],b:b+dims_in[i]]=diag[a:a+dims_out[i],b:b+dims_in[i]]
    matrix[a:a+dims_out[i],b+dims_in[i]:] =upper[a:a+dims_out[i],b+dims_in[i]:]
    a+=dims_out[i];b+=dims_in[i]
plt.figure()
plt.matshow(matrix)

In [None]:
# alternative test matrices
#T = np.random.rand(32,32)
#T =np.arange(1,32).reshape(1,-1)*np.arange(1,32).reshape(-1,1)
# use the structured matrix built in the cell above
T = matrix
#T = mats = get_mobilenet_target_mats()[0]
# build the initial system, plot it and check its dimensions
sys = Split.initial_sigmas_mixed(T)
utils.show_system(sys,mark_D=False)
utils.check_dims(sys)

In [None]:
def identification_split_system(sys,N):
    """
    Split every stage of the system in the middle, repeated N times.

    In every pass each stage k is split with Split.split_sigmas_mixed at half
    of its input dimension and half of its output dimension. Nothing is
    returned; the callers keep using `sys` afterwards, so the split is
    presumably applied to `sys` in place (see Split.split_sigmas_mixed).

    sys:    system, the stages are read from sys.causal_system.stages
    N:      number of passes over all stages
    """
    for n in range(N):
        print(n)
        #reverse ordering makes indexing easier
        for k in reversed(range(len(sys.causal_system.stages))):
            stage = sys.causal_system.stages[k]
            i_in = stage.dim_in//2
            i_out = stage.dim_out//2
            Split.split_sigmas_mixed(sys,k,i_in,i_out)

In [None]:
# reference: split every stage in the middle, 3 passes
sys = Split.initial_sigmas_mixed(T)
identification_split_system(sys,3)
utils.check_dims(sys)
utils.show_system(sys)

In [None]:
def identification_split_clustering(sys,N,N_seg=10):
    """
    Split every stage of the system according to the clustering, repeated N times.

    For every stage the columns and rows are segmented with segment_matrix.
    The stage is permuted such that the clusters are contiguous, the
    permutation is collected in the total permutations, and the stage is
    split with Split.split_sigmas_mixed.

    sys:      system with causal_system, anticausal_system, dims_in and dims_out
    N:        number of passes over all stages
    N_seg:    number of iterations of segment_matrix for every stage

    returns:  tuple (P_col,P_row) with the collected column and row permutations
    """
    P_col = np.arange(np.sum(sys.dims_in) ,dtype=int)
    P_row = np.arange(np.sum(sys.dims_out),dtype=int)

    for n in range(N):
        print(n)
        #reverse ordering makes indexing easier
        for k in reversed(range(len(sys.causal_system.stages))):
            stage_c=sys.causal_system.stages[k]
            stage_a=sys.anticausal_system.stages[k]
            #only the final masks for the columns/rows of D are needed here
            *_,s_c,s_r = segment_matrix(stage_c,stage_a,N=N_seg)

            assert len(s_c)==stage_c.dim_in ,"dims_in causal do not match s_c"
            assert len(s_r)==stage_c.dim_out,"dims_out causal do not match s_r"
            assert len(s_c)==stage_a.dim_in ,"dims_in antic do not match s_c"
            assert len(s_r)==stage_a.dim_out,"dims_out antic do not match s_r"

            #the same permutation is applied to the causal and the anticausal stage
            p_col,p_row,i_in,i_out = get_permutations(s_c,s_r)
            permute_stage(stage_c,p_col,p_row)
            permute_stage(stage_a,p_col,p_row)
            collect_permutations(P_col,P_row,k,p_col,p_row,sys)

            Split.split_sigmas_mixed(sys,k,i_in,i_out)

    return P_col,P_row

In [None]:
# regularization weight read by f_reg_row/f_reg_col
gamma = 1e2
# split according to the clustering, 2 passes
sys = Split.initial_sigmas_mixed(T)
P_col,P_row = identification_split_clustering(sys,2)
utils.check_dims(sys)
utils.show_system(sys)

In [None]:
# left: collected column permutation (this has to plot P_col, not P_row)
plt.subplot(1,2,1)
plt.scatter(np.arange(len(P_col)),P_col)
# right: collected row permutation
plt.subplot(1,2,2)
plt.scatter(np.arange(len(P_row)),P_row)

In [None]:
# s_in of every stage; the first causal and the last anticausal stage are skipped
# (presumably s_in holds the singular values at the stage input -- confirm in Split)
sigmas_causal =[stage.s_in for stage in sys.causal_system.stages][1:]
sigmas_anticausal =[stage.s_in for stage in sys.anticausal_system.stages][:-1]
#print(sigmas_causal)
#print(sigmas_anticausal)
# left: causal part, right: anticausal part, one scatter series per stage
plt.subplot(1,2,1)
for sig in sigmas_causal:
    plt.scatter(np.arange(len(sig)),sig)
plt.subplot(1,2,2)
for sig in sigmas_anticausal:
    plt.scatter(np.arange(len(sig)),sig)

In [None]:
# show the s_in values of the anticausal stages
sigmas_anticausal

In [None]:
# T with the collected row and column permutations applied
plt.matshow(T[P_row][:,P_col])

In [None]:
# difference between the permuted T and the matrix represented by the system
plt.matshow(T[P_row][:,P_col]-sys.to_matrix())
# largest absolute entry of the difference
np.max(np.abs(T[P_row][:,P_col]-sys.to_matrix()))

In [None]:
# check the dimensions of the final system
utils.check_dims(sys)