# Clustering columns and rows

The idea is to cluster columns and rows.

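Here we have the objective function that `segment_matrix` below evaluates for a boolean row selection $s_r$ and a boolean column selection $s_c$ (smaller is better):

$$f(s_r, s_c) = \lVert X[s_r, s_c] \rVert_F^2 + \lVert X[\neg s_r, \neg s_c] \rVert_F^2$$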



In [None]:
import numpy as np
import tvsclib.utils as utils
import Split
import matplotlib.pyplot as plt
from tvsclib.strict_system import StrictSystem

from tvsclib.approximation import Approximation

import torchvision.models as models
import torch
import scipy.stats 

import graphs

import scipy.linalg as linalg

import plot_permutations as perm

In [None]:
def get_mobilenet_target_mats():
    """Return the weight matrices of the linear layers in the MobileNetV2 classifier.

    The model is created with ``pretrained=True`` and switched to eval mode.
    The result is a list with one NumPy array per ``torch.nn.Linear`` layer,
    in the order the layers appear in ``model.classifier``.
    """
    # Load the model and put it into eval mode
    net = models.mobilenet_v2(pretrained=True)
    net.eval()
    # Keep only the weights of the fully connected layers
    return [
        module.weight.detach().numpy()
        for module in net.classifier
        if isinstance(module, torch.nn.Linear)
    ]

In [None]:
# Build the initial mixed system for a random 64x64 test matrix,
# then display it and check its dimensions.
T = np.random.rand(64,64)
#T = mats = get_mobilenet_target_mats()[0]
sys = Split.initial_mixed(T)
utils.show_system(sys,mark_D=False)
utils.check_dims(sys)

## Some helper functions

In [None]:
def f(X,s_col,s_row):
    """Placeholder for the objective function; currently always returns 0."""
    return 0

def compute_sigmasf(X,s_col,s_row):
    return np.linalg.svd(X[s_row][:,s_col],compute_uv=False), np.linalg.svd(X[~s_row][:,~s_col],compute_uv=False)

def show_matrices(X,s_col,s_row):
    """Show the selected block in figure 1 and the complementary block in figure 2."""
    # block with the selected rows and columns
    selected = X[s_row][:,s_col]
    plt.matshow(selected,fignum=1)

    # block with the remaining rows and columns
    remaining = X[~s_row][:,~s_col]
    plt.matshow(remaining,fignum=2)

gamma = 1e4  # maybe have two different regularizations for rows and columns


def f_reg_row(x):
    """Cubic regularization term for rows; reads the global ``gamma`` at call time."""
    return -gamma*x**3


def f_reg_col(x):
    """Cubic regularization term for columns; reads the global ``gamma`` at call time."""
    return -gamma*x**3

## Iteration

In [None]:
def _pick_extremes(candidates, scores, q, largest):
    """Return up to ``q`` entries of ``candidates`` with the smallest (or largest) ``scores``."""
    if candidates.size == 0:
        return candidates
    values = scores[candidates]
    if q == 1:
        # argmin/argmax keep the first-occurrence tie breaking of the single-move case
        best = np.argmax(values) if largest else np.argmin(values)
        return candidates[[best]]
    order = np.argsort(values)
    return candidates[order[-q:]] if largest else candidates[order[:q]]


def _move_entries(sel, lo, hi, scores, v_reg, q):
    """Update the free part ``sel[lo:hi]`` of a selection mask in place.

    Up to ``q`` selected entries with the lowest scores are deselected if their
    score is below ``v_reg``; up to ``q`` unselected entries with the highest
    scores are selected if their score is above ``v_reg``. If nothing passes the
    threshold, the lowest selected and the highest unselected entry are swapped.

    Returns True if such a swap was forced.
    """
    # Work on a snapshot: sel[lo:hi] is a view, so picking candidates from it
    # after the first assignment would offer the entry that was just removed
    # again and silently undo the forced swap.
    inner = sel[lo:hi].copy()
    inner_scores = scores[lo:hi]

    drop = _pick_extremes(np.flatnonzero(inner), inner_scores, q, largest=False)
    add = _pick_extremes(np.flatnonzero(~inner), inner_scores, q, largest=True)
    drop_ok = drop[inner_scores[drop] < v_reg]
    add_ok = add[inner_scores[add] > v_reg]

    sel[lo + drop_ok] = False
    sel[lo + add_ok] = True

    # a swap needs one candidate on each side
    forced = drop_ok.size == 0 and add_ok.size == 0 and drop.size > 0 and add.size > 0
    if forced:
        sel[lo + drop[0]] = False
        sel[lo + add[-1]] = True
    return forced


def _selection_cost(Xs, s_r, s_c):
    """Sum of ``Xs`` over (selected rows, selected cols) and (unselected rows, unselected cols)."""
    return np.sum(Xs[s_r][:,s_c])+ np.sum(Xs[~s_r][:,~s_c])


def segment_matrix(stage_causal,stage_anticausal,N=70,initla_spectral=True): #minimize frobenius norm
    """Iteratively search a row/column selection for splitting one stage.

    The matrix

        X = [[0,  Ba, Aa],
             [Cc, D,  Ca],
             [Ac, Bc, 0 ]]

    is built from the stage matrices. A boolean column mask ``s_c`` and row
    mask ``s_r`` are updated for at most ``N`` iterations; the tracked cost is
    the sum of squared entries of ``X`` in the (selected rows, selected columns)
    block plus the (unselected rows, unselected columns) block. The first
    ``Ac.shape[1]`` columns and the last ``Ac.shape[0]`` rows are always
    selected, the last ``Aa.shape[1]`` columns and the first ``Aa.shape[0]``
    rows never; only the entries belonging to ``D`` are moved.

    Parameters
    ----------
    stage_causal, stage_anticausal:
        Objects providing ``A_matrix``, ``B_matrix`` and ``C_matrix``; ``D_matrix``
        is read from ``stage_causal`` only.
    N:
        Maximum number of iterations. ``N=0`` returns the initial selection.
    initla_spectral:
        Unused; kept (including its spelling) for backward compatibility.

    Returns
    -------
    (s_c, s_r, report):
        Boolean masks over the columns and rows of ``D`` for the iterate with
        the lowest cost, and a dict with the history: ``"s_cols"``/``"s_rows"``
        (full masks per iterate), ``"f"`` (cost per iterate), ``"q"`` (entries
        moved per step) and ``"X"`` (always 0).

    Notes
    -----
    Reads the module-level ``gamma`` as regularization weight, prints progress
    messages and uses ``np.random`` for restarts.
    """
    Ac = stage_causal.A_matrix
    Bc = stage_causal.B_matrix
    Cc = stage_causal.C_matrix

    Aa = stage_anticausal.A_matrix
    Ba = stage_anticausal.B_matrix
    Ca = stage_anticausal.C_matrix

    D = stage_causal.D_matrix

    #dims of states
    (d_out_c,d_in_c) = Ac.shape
    (d_out_a,d_in_a) = Aa.shape

    #setup matrix
    X = np.block([[np.zeros((d_out_a,d_in_c)),Ba,Aa ],
                  [Cc,D,Ca],
                  [Ac,Bc,np.zeros((d_out_c,d_in_a))]
    ])

    # index ranges of the columns/rows that may be moved (those belonging to D)
    col_lo, col_hi = d_in_c, X.shape[1]-d_in_a
    row_lo, row_hi = d_out_a, X.shape[0]-d_out_c

    s_c = np.zeros(X.shape[1],dtype=bool)
    s_r = np.zeros(X.shape[0],dtype=bool)

    #initialize based on the Bs and Cs
    order = np.argsort(np.linalg.norm(Bc,axis=0)-np.linalg.norm(Ba,axis=0))
    s_c[col_lo+order[:len(order)//2]]=1
    order = np.argsort(np.linalg.norm(Cc,axis=1)-np.linalg.norm(Ca,axis=1))
    s_r[row_lo+order[:len(order)//2]]=1

    #set the fixed
    s_c[:col_lo]=1
    s_c[col_hi:]=0
    s_r[:row_lo]=0
    s_r[row_hi:]=1

    fs = np.zeros(N+1)
    s_cols=np.zeros((N+1,X.shape[1]),dtype=bool)
    s_rows=np.zeros((N+1,X.shape[0]),dtype=bool)

    Xs = X**2

    s_cols[0]=s_c
    s_rows[0]=s_r
    fs[0]=_selection_cost(Xs,s_r,s_c)

    # number of entries moved per step and direction
    q = int(np.ceil(min(D.shape)/2e2))
    n_restart = -100000

    # number of completed iterations; stays 0 when N == 0
    n_done = 0

    for n in range(N):
        # squared column norms restricted to the unselected / selected rows
        n_xt = np.sum(Xs[~s_r],axis=0)
        n_xb = np.sum(Xs[s_r],axis=0)
        # squared row norms restricted to the unselected / selected columns
        n_xr = np.sum(Xs[:,~s_c],axis=1)
        n_xl = np.sum(Xs[:,s_c],axis=1)

        #only the ones we can change
        n_sel_cols = np.count_nonzero(s_c[col_lo:col_hi])
        n_sel_rows = np.count_nonzero(s_r[row_lo:row_hi])

        # NOTE(review): the first term is normalised by the count over the full
        # mask, the second by the count over the free entries only although the
        # sum also runs over the fixed ones. Kept as in the original - confirm
        # whether this mismatch is intended.
        S_col = n_xt/np.count_nonzero(~s_r) -n_xb/n_sel_rows
        S_row = n_xr/np.count_nonzero(~s_c) -n_xl/n_sel_cols

        # thresholds: cubic penalty on the deviation from a half/half split
        v_reg_col = gamma*(n_sel_cols/(col_hi-col_lo)-0.5)**3
        v_reg_row = gamma*(n_sel_rows/(row_hi-row_lo)-0.5)**3

        # both score vectors were computed before any move of this iteration
        if _move_entries(s_c,col_lo,col_hi,S_col,v_reg_col,q):
            print("flip")
        if _move_entries(s_r,row_lo,row_hi,S_row,v_reg_row,q):
            print("flip")

        f_new = _selection_cost(Xs,s_r,s_c)
        if f_new > fs[0] and n > n_restart + 50: #worse than initial -> do restart with other initial
            print("restart at n=",n)
            n_restart= n
            # restart from the best iterate so far with a few random flips
            i_best = np.argmin(fs[:n+1])
            s_c = s_cols[i_best].copy()
            s_r = s_rows[i_best].copy()
            v = np.random.randint(col_lo,col_hi,3*q)
            s_c[v] = ~s_c[v]
            print(v)
            v = np.random.randint(row_lo,row_hi,3*q)
            print(v)
            s_r[v] = ~s_r[v]
            f_new = _selection_cost(Xs,s_r,s_c)

        fs[n+1] = f_new
        s_cols[n+1]=s_c
        s_rows[n+1]=s_r
        n_done = n+1

        # stop as soon as a previously visited selection shows up again
        seen_before = np.logical_and(np.all(s_cols[:n+1]==s_c,axis=1),
                                     np.all(s_rows[:n+1]==s_r,axis=1))
        if np.any(seen_before):
            print("converged at n=",n)
            break

    #get minimum f
    i_min = np.argmin(fs[:n_done+1])
    s_c = s_cols[i_min]
    s_r = s_rows[i_min]
    print("frac cols:",np.count_nonzero(s_c[col_lo:col_hi])/D.shape[1])
    print("frac rows:",np.count_nonzero(s_r[row_lo:row_hi])/D.shape[0])

    report ={"s_cols":s_cols[:n_done+1],"s_rows":s_rows[:n_done+1],"X":0,"f":fs[:n_done+1],"q":q}

    return s_c[col_lo:col_hi],s_r[row_lo:row_hi],report
        


In [None]:
np.random.randint?

In [None]:
# Run the segmentation on the first causal/anticausal stage of a fresh initial system.
gamma =1e2  # regularization weight read by segment_matrix as a global
q =25  # NOTE(review): not passed to segment_matrix, which assigns its own local q
sys = Split.initial_mixed(T)
stage_c=sys.causal_system.stages[0]
stage_a=sys.anticausal_system.stages[0]
s_c,s_r,report = segment_matrix(stage_c,stage_a,N=50)
# per-iteration history of the column and row masks
s_cols = report["s_cols"]
s_rows = report["s_rows"]

In [None]:
# Singular values of the two selected blocks of T for the first and the last
# stored iterate (the last iterate is not necessarily the one with the lowest f).
s_start_c,s_start_a = compute_sigmasf(T,s_cols[0],s_rows[0])
s_end_c,s_end_a = compute_sigmasf(T,s_cols[-1],s_rows[-1])

# left: block with the selected rows/columns
plt.subplot(1,2,1)
plt.plot(s_start_c,label="start")
plt.plot(s_end_c,label="end")
plt.legend()
# right: block with the complementary rows/columns (no legend is drawn here)
plt.subplot(1,2,2)
plt.plot(s_start_a,label="start")
plt.plot(s_end_a,label="end")

In [None]:
print(np.count_nonzero(s_start_a>2)+np.count_nonzero(s_start_c>1.5))
print(np.count_nonzero(s_end_a>2)+np.count_nonzero(s_end_c>1.5))

In [None]:
plt.spy(s_cols)
print(np.count_nonzero(s_cols[-1])/len(s_cols[-1]))

In [None]:
# History of the row selection and the fraction of selected rows in the last iterate
plt.spy(s_rows)
print(np.count_nonzero(s_rows[-1])/len(s_rows[-1]))

In [None]:
plt.plot(report["f"])

In [None]:
np.all(s_cols[0]==s_cols[1])

## Some helper functions to apply a permutation and store it

In [None]:
def get_permutations(s_c,s_r):
    """Turn the boolean selections into permutations and split indices.

    s_c:    boolean mask over the columns
    s_r:    boolean mask over the rows

    Returns ``(p_col, p_row, i_in, i_out)``:

    p_col:  column permutation, selected columns first, then the unselected ones
    p_row:  row permutation, unselected rows first, then the selected ones
    i_in:   size of the leading column group, i.e. the number of selected columns
    i_out:  size of the leading row group, i.e. the number of unselected rows
    """
    p_col = np.concatenate([np.flatnonzero(s_c),np.flatnonzero(~s_c)])
    p_row = np.concatenate([np.flatnonzero(~s_r),np.flatnonzero(s_r)])
    i_in = np.count_nonzero(s_c)
    # p_row starts with the unselected rows, so the split position is their
    # count (the original counted the selected rows, which are placed last)
    i_out = np.count_nonzero(~s_r)
    return p_col,p_row,i_in,i_out

def permute_stage(stage,p_col,p_row):
    """Permute the inputs and outputs of ``stage`` in place.

    The columns of ``B_tilde`` and ``D_matrix`` are reordered with ``p_col``,
    the rows of ``C_tilde`` and ``D_matrix`` with ``p_row``.
    """
    # input side
    B_permuted = stage.B_tilde[:,p_col]
    stage.B_tilde = B_permuted
    # output side
    C_permuted = stage.C_tilde[p_row,:]
    stage.C_tilde = C_permuted
    # D is affected by both permutations
    D_permuted = stage.D_matrix[p_row][:,p_col]
    stage.D_matrix = D_permuted
    
def collect_permutations(P_col,P_row,k,p_col,p_row,system):
    """
    Fold the permutation of one stage into the total permutations (in place).

    P_col:    total permutation of columns
    P_row:    total permutation of rows
    k:        index of stage
    p_col:    new column permutation of stage k
    p_row:    new row permutation of stage k
    system:   object providing dims_in and dims_out
    """
    updates = (
        (P_col, p_col, system.dims_in),
        (P_row, p_row, system.dims_out),
    )
    for total, local, dims in updates:
        # offset of stage k inside the total permutation
        start = int(np.sum(dims[:k]))
        stop = start + dims[k]
        total[start:stop] = total[start:stop][local]

## Get test matrix

In [None]:
# Block sizes of the reference segmentation
dims_in =  np.array([4, 4, 4, 4])*6
dims_out = np.array([4, 4, 4, 4])*6

#dims_in =  np.array([9, 7, 7, 9])*3
#dims_out = np.array([7, 9, 9, 7])*3


# rank of each of the three parts (lower, diagonal, upper)
n = 2
# create orthogonal vectors and scale them by the size of their block
# (i.e. norm(block)/size(block) = const)
Us =np.vstack([scipy.stats.ortho_group.rvs(dims_out[i])[:,:3*n]*dims_out[i] for i in range(len(dims_in))])
Vts=np.hstack([scipy.stats.ortho_group.rvs(dims_in[i])[:3*n,:]*dims_in[i] for i in range(len(dims_in))])

# singular values shared by the three parts
s = np.linspace(1,0.75,n)

# three rank-n matrices built from disjoint groups of vectors
lower = Us[:,:n]@np.diag(s)@Vts[:n,:]
diag = Us[:,n:2*n]@np.diag(s)@Vts[n:2*n,:]
upper = Us[:,2*n:3*n]@np.diag(s)@Vts[2*n:3*n,:]
matrix = np.zeros_like(diag)
# a/b: row/column offset of the current diagonal block
a=0;b=0
for i in range(len(dims_in)):
    # left of the diagonal block: lower part, the block itself: diag part, right of it: upper part
    matrix[a:a+dims_out[i],:b]            =lower[a:a+dims_out[i],:b]
    matrix[a:a+dims_out[i],b:b+dims_in[i]]=diag[a:a+dims_out[i],b:b+dims_in[i]]
    matrix[a:a+dims_out[i],b+dims_in[i]:] =upper[a:a+dims_out[i],b+dims_in[i]:]
    a+=dims_out[i];b+=dims_in[i]
plt.figure()

# hide the block structure behind random row and column permutations
P_in_ref = np.random.permutation(np.arange(matrix.shape[1]))
P_out_ref= np.random.permutation(np.arange(matrix.shape[0]))

T = matrix[P_out_ref][:,P_in_ref]
plt.matshow(T)
print(T.shape)

In [None]:
#T = np.random.rand(32,32)
#T =np.arange(1,32).reshape(1,-1)*np.arange(1,32).reshape(-1,1)
#T = matrix.T
#T = mats = get_mobilenet_target_mats()[0]
sys = Split.initial_sigmas_mixed(T)
utils.show_system(sys,mark_D=False)
utils.check_dims(sys)

## Combine it into an algorithm

In [None]:
def identification_split_system(sys,N):
    """Split every stage of ``sys`` in the middle, repeated ``N`` times.

    In each of the ``N`` passes every stage ``k`` of the causal system is
    handed to ``Split.split_sigmas_mixed`` with half of its input and output
    dimension (integer division) as split indices. The system is modified in
    place and nothing is returned; no permutation is applied.

    sys:    system to split
    N:      number of passes over all stages
    """
    for n in range(N):
        print(n)
        # reverse ordering makes indexing easier
        for k in reversed(range(len(sys.causal_system.stages))):
            stage = sys.causal_system.stages[k]
            i_in = stage.dim_in//2
            i_out = stage.dim_out//2
            Split.split_sigmas_mixed(sys,k,i_in,i_out)
    

In [None]:
sys = Split.initial_sigmas_mixed(T)
identification_split_system(sys,2)
utils.check_dims(sys)
utils.show_system(sys)

In [None]:
sigmas_causal =[stage.s_in for stage in sys.causal_system.stages][1:]
sigmas_anticausal =[stage.s_in for stage in sys.anticausal_system.stages][:-1]
#print(sigmas_causal)
#print(sigmas_anticausal)
plt.subplot(1,2,1)
for sig in sigmas_causal:
    plt.scatter(np.arange(len(sig)),sig)
plt.subplot(1,2,2)
for sig in sigmas_anticausal:
    plt.scatter(np.arange(len(sig)),sig)

In [None]:
def identification_split_clustering(sys,N,N_split = 50):
    """Split every stage ``N`` times, permuting rows and columns before each split.

    For each stage the selection from ``segment_matrix`` is converted into a
    permutation, applied to the causal and the anticausal stage, recorded in
    the total permutations, and the stage is split with
    ``Split.split_sigmas_mixed``. The system is modified in place.

    sys:      system to split
    N:        number of passes over all stages
    N_split:  iteration limit handed to ``segment_matrix``

    Returns ``(Ps_col, Ps_row, reports)``: the total column and row
    permutations after each pass (row 0 is the identity) and the list of
    reports returned by ``segment_matrix``.
    """
    total_col = np.arange(np.sum(sys.dims_in) ,dtype=int)
    total_row = np.arange(np.sum(sys.dims_out),dtype=int)

    # history of the total permutations, one row per pass plus the identity
    Ps_col = np.zeros((N+1,total_col.size),dtype=int)
    Ps_row = np.zeros((N+1,total_row.size),dtype=int)
    Ps_col[0] = total_col
    Ps_row[0] = total_row

    reports = []
    for n in range(N):
        print(n)
        # reverse ordering makes indexing easier
        for k in reversed(range(len(sys.causal_system.stages))):
            causal = sys.causal_system.stages[k]
            anticausal = sys.anticausal_system.stages[k]
            s_c,s_r,report = segment_matrix(causal,anticausal,N=N_split)
            reports.append(report)

            assert len(s_c)==causal.dim_in ,"dims_in causal do not match s_c"
            assert len(s_r)==causal.dim_out,"dims_out causal do not match s_r"
            assert len(s_c)==anticausal.dim_in ,"dims_in antic do not match s_c"
            assert len(s_r)==anticausal.dim_out,"dims_out antic do not match s_r"

            p_col,p_row,i_in,i_out = get_permutations(s_c,s_r)
            for stage in (causal,anticausal):
                permute_stage(stage,p_col,p_row)
            collect_permutations(total_col,total_row,k,p_col,p_row,sys)

            Split.split_sigmas_mixed(sys,k,i_in,i_out)
        #save the Permutations collected for all stages
        Ps_col[n+1] = total_col
        Ps_row[n+1] = total_row

    return Ps_col,Ps_row,reports

In [None]:
gamma = 1e2
sys = Split.initial_sigmas_mixed(T)
Ps_col,Ps_row,reports = identification_split_clustering(sys,2)
P_col = Ps_col[-1]
P_row = Ps_row[-1]
utils.check_dims(sys)
utils.show_system(sys)

In [None]:
plt.subplot(1,2,1)
plt.scatter(np.arange(len(P_col)),P_col)
plt.subplot(1,2,2)
plt.scatter(np.arange(len(P_row)),P_row)

In [None]:
sigmas_causal =[stage.s_in for stage in sys.causal_system.stages][1:]
sigmas_anticausal =[stage.s_in for stage in sys.anticausal_system.stages][:-1]
#print(sigmas_causal)
#print(sigmas_anticausal)
plt.subplot(1,2,1)
for i,sig in enumerate(sigmas_causal):
    plt.scatter(np.arange(len(sig)),sig,label=str(i))
plt.legend()
plt.subplot(1,2,2)
for i,sig in enumerate(sigmas_anticausal):
    plt.scatter(np.arange(len(sig)),sig,label=str(i))
plt.legend()

In [None]:
plt.matshow(T[P_row][:,P_col])

In [None]:
plt.matshow(T[Ps_row[-1]][:,Ps_col[-1]])

In [None]:
utils.check_dims(sys)

### TODO: Some asymmetry here

The algorithm usually recovers the row clustering, but is unable to recover the column clustering.
This is regardless of whether the matrix is transposed or not.

In [None]:
# Color every column of T by the reference block it was taken from
# (T's columns are matrix's columns reordered by P_in_ref) and plot the column permutations.
cmap = plt.cm.get_cmap('tab20')
colors = np.repeat(cmap((1/20)*np.arange(4)+0.001),dims_in,axis=0)[P_in_ref]
perm.multiple_connection_plot(perm.invert_permutations(Ps_col),colors=colors,start=0,end=2)

In [None]:
# Color every row of T by the reference block it was taken from and plot the
# row permutations. T's rows are matrix's rows reordered by P_out_ref, so the
# colors have to be built from dims_out and P_out_ref (not the column data).
cmap = plt.cm.get_cmap('tab20')
colors = np.repeat(cmap((1/20)*np.arange(4)+0.001),dims_out,axis=0)[P_out_ref]
perm.multiple_connection_plot(perm.invert_permutations(Ps_row),colors=colors,start=0,end=2)

# Weight matrix from MobileNet

In [None]:
T = get_mobilenet_target_mats()[0]

In [None]:
sys = Split.initial_sigmas_mixed(T)
identification_split_system(sys,3)
utils.check_dims(sys)
utils.show_system(sys)

In [None]:
sigmas_causal =[stage.s_in for stage in sys.causal_system.stages][1:]
sigmas_anticausal =[stage.s_in for stage in sys.anticausal_system.stages][:-1]
#print(sigmas_causal)
#print(sigmas_anticausal)
plt.subplot(1,2,1)
for sig in sigmas_causal:
    plt.scatter(np.arange(len(sig)),sig)
plt.subplot(1,2,2)
for sig in sigmas_anticausal:
    plt.scatter(np.arange(len(sig)),sig)

In [None]:
gamma = 1e1
#q = 20
sys_per = Split.initial_sigmas_mixed(T)
Ps_col,Ps_row,reports = identification_split_clustering(sys_per,3,N_split=200)
utils.check_dims(sys_per)
utils.show_system(sys_per)

In [None]:
sigmas_causal =[stage.s_in for stage in sys_per.causal_system.stages][1:]
sigmas_anticausal =[stage.s_in for stage in sys_per.anticausal_system.stages][:-1]
#print(sigmas_causal)
#print(sigmas_anticausal)
plt.subplot(1,2,1)
for sig in sigmas_causal:
    plt.scatter(np.arange(len(sig)),sig)
plt.subplot(1,2,2)
for sig in sigmas_anticausal:
    plt.scatter(np.arange(len(sig)),sig)

In [None]:
# Collect the s_in values of the stages of the permuted system (sys_per) and of
# the system without permutation (sys); the first causal and the last
# anticausal stage are left out.
sigmas_causal_per =[stage.s_in for stage in sys_per.causal_system.stages][1:]
sigmas_anticausal_per =[stage.s_in for stage in sys_per.anticausal_system.stages][:-1]

sigmas_causal =[stage.s_in for stage in sys.causal_system.stages][1:]
sigmas_anticausal =[stage.s_in for stage in sys.anticausal_system.stages][:-1]

plt.figure(figsize=[12,8])

# left: causal part, right: anticausal part; C0 = sys, C1 = sys_per
plt.subplot(1,2,1)
plt.grid()
for sig in sigmas_causal:
    plt.plot(np.arange(len(sig)),sig,color='C0')
for sig in sigmas_causal_per:
    plt.plot(np.arange(len(sig)),sig,color='C1')
plt.subplot(1,2,2)
for sig in sigmas_anticausal:
    plt.plot(np.arange(len(sig)),sig,color='C0')
for sig in sigmas_anticausal_per:
    plt.plot(np.arange(len(sig)),sig,color='C1')
plt.grid()


In [None]:
T_per = T[Ps_row[-1]][:,Ps_col[-1]]
np.max(np.abs(T_per-sys_per.to_matrix()))

In [None]:
eps_max = max([np.max(sig)for sig in sigmas_causal]+[np.max(sig)for sig in sigmas_anticausal])

In [None]:
# Approximation objects for the system without (sys) and with (sys_per) permutation
approx =Approximation(sys,(sigmas_causal,sigmas_anticausal))
approx_per=Approximation(sys_per,(sigmas_causal_per,sigmas_anticausal_per))


N = 9 #number of points
alpha = np.linspace(0,1,N)

# NOTE(review): err_move is not used by the cells visible here
err_move =np.zeros_like(alpha)

# thresholds from 0 up to eps_max
eps = eps_max*alpha

def calc_values(approx,eps,matrix):
    """Evaluate error and cost of the approximation for every threshold in ``eps``.

    approx:   object providing ``get_approxiamtion(eps)``
    eps:      array of thresholds
    matrix:   reference matrix the approximations are compared with

    Returns ``(err, costs)``: the spectral norm (``ord=2``) of the difference
    between the approximated system's matrix and ``matrix``, and the value of
    ``cost()`` of the approximated system, one entry per threshold.
    """
    err = np.zeros_like(eps)
    costs = np.zeros_like(eps)
    for idx, threshold in enumerate(eps):
        reduced = approx.get_approxiamtion(threshold)
        difference = reduced.to_matrix()-matrix
        err[idx] = np.linalg.norm(difference,ord=2)
        costs[idx] = reduced.cost()
    return err,costs

err_orig,cost_orig = calc_values(approx,eps,T)
err_per,cost_per = calc_values(approx_per,eps,T_per)

In [None]:
plt.plot(cost_orig,err_orig,label="orig")
plt.plot(cost_per,err_per,label="per")
plt.legend()
plt.grid()

In [None]:
alpha[2]

In [None]:
err_orig[2]

In [None]:
err_per[2]

In [None]:
cost_orig[2]

In [None]:
cost_per[2]

In [None]:
i = 2
print(alpha[i])
print(1-cost_per[i]/cost_orig[i])

In [None]:
print(err_per[i]/err_orig[i])

In [None]:
plt.plot(alpha,cost_per/cost_orig)

In [None]:
eps[2]

In [None]:
for report in reports:
    plt.figure()
    plt.plot(report["f"])
#    plt.figure()
#    plt.spy(report["s_cols"])
#    plt.figure()
#    plt.spy(report["s_rows"])

In [None]:
plt.spy(reports[4]["s_rows"])

In [None]:
plt.spy(reports[4]["s_cols"])

In [None]:
reports[4]["q"]

In [None]:
reports[4]["s_cols"][3]

In [None]:
reports[4]["s_cols"].dtype

In [None]:
def get_AlexNet_target_mats():
    """Return the weight matrices of the linear layers in the AlexNet classifier.

    The model is created with ``pretrained=True`` and switched to eval mode.
    The result is a list with one NumPy array per ``torch.nn.Linear`` layer,
    in the order the layers appear in ``model.classifier``.
    """
    # Load the model and put it into eval mode
    net = models.alexnet(pretrained=True)
    net.eval()
    # Keep only the weights of the fully connected layers
    return [
        module.weight.detach().numpy()
        for module in net.classifier
        if isinstance(module, torch.nn.Linear)
    ]
mat_AlexNet = get_AlexNet_target_mats()[0]

In [None]:
T = mat_AlexNet