# Clustering to reduce $\|H\|_F$

The idea is to cluster columns and rows.

Here we have the objective function 



In [None]:
import numpy as np
import tvsclib.utils as utils
import Split
import matplotlib.pyplot as plt
from tvsclib.strict_system import StrictSystem

from tvsclib.approximation import Approximation

import torchvision.models as models
import torch
import scipy.stats 

import graphs

import scipy.linalg as linalg

import plot_permutations as perm
import setup_plots
from mpl_toolkits.axes_grid1 import make_axes_locatable

import Split
import Split_permute as SplitP

In [None]:
# Run the project-local plot setup, then raise the figure resolution.
setup_plots.setup()
plt.rcParams.update({'figure.dpi': 150})

In [None]:
def get_mobilenet_target_mats():
    """Collect the weight matrices of the linear layers in MobileNetV2's classifier.

    Loads torchvision's ``mobilenet_v2`` with ``pretrained=True``, switches it
    to eval mode and walks ``model.classifier``.

    Returns:
        list: one NumPy array per ``torch.nn.Linear`` layer, in classifier
        order, holding that layer's detached ``weight``.
    """
    net = models.mobilenet_v2(pretrained=True)
    # Put the model into eval mode before reading the weights.
    net.eval()
    return [
        module.weight.detach().numpy()
        for module in net.classifier
        if isinstance(module, torch.nn.Linear)
    ]

# Weight matrix from MobileNet

In [None]:
# Use the first linear-layer weight matrix of the MobileNet classifier as target.
T = get_mobilenet_target_mats()[0]

In [None]:
# Reference system: identify T with the plain split routine (second argument 3;
# presumably the number of split levels -- confirm against Split).
# NOTE(review): the name `sys` would shadow the stdlib module if it were imported.
sys = Split.identification_split_system(T,3)
utils.check_dims(sys)
utils.show_system(sys)
sys.dims_out

In [None]:
# Identify T again with the permuting split routine ("fro" strategy).
# `opts` holds three separate dicts with identical settings.
sys_per,Ps_col,Ps_row,reports = SplitP.identification_split_permute(
    T,3,strategy="fro",
    opts=[{"gamma":9e5,"N":200} for _ in range(3)])
utils.check_dims(sys_per)
utils.show_system(sys_per)

In [None]:
# Figure: the permuted system in the centre, with the column permutations
# drawn in an extra axes on top and the row permutations in one on the left.
w = setup_plots.textwidth
fig, ax = plt.subplots(figsize=(w, w))

utils.show_system(sys_per,ax=ax)
# Remember the limits of the main axes; they are restored further down after
# the shared margin axes have been populated.
y_lim = ax.get_ylim()
x_lim = ax.get_xlim()
ax.xaxis.set_ticks_position('top')

divider = make_axes_locatable(ax)
ax_dimsin = divider.append_axes("top", 1.1, pad=0.1, sharex=ax)
ax_dimsout = divider.append_axes("left", 1.1, pad=0.1, sharey=ax)

# make some labels invisible
ax_dimsin.xaxis.set_tick_params(labelbottom=False)
ax_dimsout.yaxis.set_tick_params(labelright=False)

ax_dimsin.invert_yaxis()

# plt.get_cmap works across Matplotlib versions; plt.cm.get_cmap was
# deprecated in 3.7 and removed in 3.9.
cmap = plt.get_cmap("rainbow")
colors = cmap(np.linspace(0,0.9,Ps_col.shape[1]))
perm.multiple_connection_plot(perm.invert_permutations(Ps_col),start=0,end=3,ax=ax_dimsin,N=20,flipxy=True,linewidth=0.25,colors=colors)
colors = cmap(np.linspace(0,0.9,Ps_row.shape[1]))
perm.multiple_connection_plot(perm.invert_permutations(Ps_row),start=0,end=3,ax=ax_dimsout,N=20,linewidth=0.2,colors=colors)

ax_dimsout.xaxis.set_ticks_position('top')
ax_dimsout.yaxis.set_ticks_position('right')
# Must be repeated here: changing the tick position can re-enable the labels.
ax_dimsout.yaxis.set_tick_params(labelright=False)

# Ticks of the shared directions sit on the block boundaries of the system.
ax_dimsin.set_xticks(np.cumsum(sys_per.dims_in))
ax_dimsout.set_yticks(np.cumsum(sys_per.dims_out))
ax.set_xticklabels([])
ax.set_yticklabels([])

ax_dimsin.grid()
ax_dimsout.grid()
ax_dimsout.set_xlim((0,3))
ax_dimsin.set_ylim((3,0))
ax.set_ylim(y_lim)
ax.set_xlim(x_lim)

# Iteration axis: ticks 0..3 (the earlier 1..3 tick calls were overwritten by
# these and have been dropped).
ax_dimsin.set_yticks(np.arange(0,4))
ax_dimsout.set_xticks(np.arange(0,4))
ax_dimsin.set_yticklabels(["$0$","$1$","$2$",r""],zorder=0)
# Third label fixed from a duplicated "$1$" to "$2$" so it matches its tick.
ax_dimsout.set_xticklabels(["$0$","$1$","$2$",r"$3\,.$"],zorder=0)

ax.text(0,0,r'Iteration$\qquad\quad$.',rotation=-45,\
                 horizontalalignment='right', verticalalignment='center',rotation_mode='anchor')

# `bbox` is not a savefig keyword (recent Matplotlib raises TypeError for it);
# `bbox_inches` alone requests the tight bounding box.
plt.savefig("Mobilenet_permute.pdf",bbox_inches = 'tight')
bbox = fig.get_tightbbox(fig.canvas.get_renderer())

In [None]:
# Ratio of the figure's tight bounding-box width to the requested width `w`.
bbox.width/w

In [None]:
# Per-stage `s_in` vectors (presumably singular values -- confirm in tvsclib).
# The first causal and the last anticausal stage are left out.
sigmas_causal = [stage.s_in for stage in sys.causal_system.stages][1:]
sigmas_anticausal = [stage.s_in for stage in sys.anticausal_system.stages][:-1]

sigmas_causal_per = [stage.s_in for stage in sys_per.causal_system.stages][1:]
sigmas_anticausal_per = [stage.s_in for stage in sys_per.anticausal_system.stages][:-1]

plt.figure(figsize=[12,8])

# Left panel: causal part, right panel: anticausal part.
# Reference system in 'C0', permuted system in 'C1'.
for position, reference, permuted in ((1, sigmas_causal, sigmas_causal_per),
                                      (2, sigmas_anticausal, sigmas_anticausal_per)):
    plt.subplot(1,2,position)
    for color, sigmas in (('C0', reference), ('C1', permuted)):
        for sig in sigmas:
            plt.plot(np.arange(len(sig)),sig,color=color)
    plt.grid()


In [None]:
# Reorder T with the last row and column permutations, then show the largest
# absolute difference to the matrix rebuilt from the permuted system.
T_per = T[np.ix_(Ps_row[-1], Ps_col[-1])]
np.abs(T_per-sys_per.to_matrix()).max()

In [None]:
# Largest entry over all sigma vectors of the reference (unpermuted) system.
eps_max = max(np.max(sig) for sig in sigmas_causal + sigmas_anticausal)
print(eps_max)

In [None]:
# Sweep of thresholds: N evenly spaced fractions of eps_max.
N = 9 #number of points
#N = 18 #number of points
alpha = np.linspace(0,1,N)
eps = eps_max*alpha

err_move = np.zeros_like(alpha)

# Approximation objects for the reference and the permuted system.
approx = Approximation(sys,(sigmas_causal,sigmas_anticausal))
approx_per = Approximation(sys_per,(sigmas_causal_per,sigmas_anticausal_per))

def calc_values(approx,eps,matrix):
    """Evaluate approximation error and cost for every threshold in ``eps``.

    Args:
        approx: object providing ``get_approxiamtion(threshold)`` (spelling as
            called here); the returned system must offer ``to_matrix()`` and
            ``cost()``.
        eps: one-dimensional sequence of thresholds.
        matrix: reference matrix the approximations are compared against.

    Returns:
        tuple: ``(err, costs)``, two float arrays of length ``len(eps)``.
        ``err[i]`` is the 2-norm (``ord=2``) of the difference between the
        approximated matrix and ``matrix``; ``costs[i]`` is the value returned
        by ``cost()`` of the approximated system.
    """
    n_points = len(eps)
    # Allocate float buffers explicitly: zeros_like(eps) would inherit an
    # integer dtype from integer thresholds and truncate the results.
    err = np.zeros(n_points, dtype=float)
    costs = np.zeros(n_points, dtype=float)
    for i, threshold in enumerate(eps):
        approx_system = approx.get_approxiamtion(threshold)
        err[i] = np.linalg.norm(approx_system.to_matrix() - matrix, ord=2)
        costs[i] = approx_system.cost()
    return err,costs

# Error/cost curves: reference system against T, permuted system against the
# permuted matrix T_per.
err_orig,cost_orig = calc_values(approx,eps,T)
err_per,cost_per = calc_values(approx_per,eps,T_per)

In [None]:
# Figure: approximation error over cost for the regular and the permuted
# system, with a zoomed inset around the point closest to alpha = 0.25.
w = 0.75*setup_plots.textwidth
fig, ax = plt.subplots(figsize=(w, 2/3*w))
plt.plot(cost_orig,err_orig,'1-',label='Regular system')
plt.plot(cost_per,err_per,'2--',label='Permuted system')
# Vertical marker at T.size (the number of entries of T); the y-limits are
# frozen first so the line does not rescale the axes.
ylims = ax.get_ylim()
plt.vlines(T.size,ylims[0],ylims[1],colors='0.4')
ax.ticklabel_format(axis='x',scilimits=(0,0))
ax.set_ylim(ylims)
plt.grid()
plt.legend()

# Index of the sweep point whose alpha is closest to 1/4.
i = np.argmin(np.abs(alpha-0.25))

plt.xlabel("Number of multiplications")
plt.ylabel(r'$\| M-\tilde{T} \| $')

# Inset geometry: data-space extent of the zoom window and a scale factor
# that converts it to axes-fraction size.
zoom_h = 1
zoom_w = 3e5
s = 5.2
axins = ax.inset_axes([0.45, 0.2,s*zoom_w/4e6,s*zoom_h/10])

axins.grid()
axins.plot(cost_orig,err_orig,'1-')
axins.plot(cost_per,err_per,'2--')

axins.set_xlim(cost_orig[i]-0.5*zoom_w, cost_orig[i]+0.5*zoom_w)
axins.set_ylim(err_orig[i]-0.5*zoom_h, err_orig[i]+0.5*zoom_h)
axins.set_xticklabels([])
axins.set_yticklabels([])

# Annotate the two end points of the sweep.
text = ax.text(cost_orig[0]-3e5,err_orig[0]+0.7, r'$\epsilon = 0$',
                  bbox={'facecolor': 'white',"edgecolor":"black", 'alpha': 0.5, 'pad': 0,"linewidth":0})
text = ax.text(cost_orig[-1]+1.1e5,err_orig[-1]-.5, r'$\epsilon = \|M\|_H$',
                  bbox={'facecolor': 'white',"edgecolor":"black", 'alpha': 0.5, 'pad': 0,"linewidth":0})

ax.indicate_inset_zoom(axins, edgecolor="black")
axins.text(cost_orig[i]-0.15*zoom_w,err_orig[i]+0.3*zoom_h, r'$\epsilon = \frac{1}{4} \|M\|_H$',
          bbox={'facecolor': 'white',"edgecolor":"black", 'alpha': 0.5, 'pad': 0,"linewidth":0})
# `bbox` is not a savefig keyword (recent Matplotlib raises TypeError for it);
# `bbox_inches` alone requests the tight bounding box.
plt.savefig("perm_example_mobilenet_error.pdf",bbox_inches = 'tight')
bbox = plt.gcf().get_tightbbox( plt.gcf().canvas.get_renderer())
print(bbox.width/setup_plots.textwidth)

In [None]:
# Summary of the sweep point with index 2.
i = 2
for caption, number in (
        ("alpha=", alpha[i]),
        ("eps=", eps[i]),
        ("Cost original=", cost_orig[i]),
        ("Cost new=", cost_per[i]),
        ("Cost new/Cost orig=", cost_per[i]/cost_orig[i]),
):
    print(caption, number)

In [None]:
# Compare eps times the number of stages with the measured error at index n.
n = 2
for caption, number in (
        ("eps * K = ", eps[n]*len(sys.dims_in)),
        ("err =", err_orig[n]),
):
    print(caption, number)

In [None]:
# Sparsity pattern of the "s_rows" entry of the first report.
plt.spy(reports[0]["s_rows"])

In [None]:
# Figure: state dimensions of the approximated regular and permuted systems
# at the sweep point closest to alpha = 0.25.

# Select the index explicitly instead of relying on whatever value `i` was
# left with by the cell that happened to run last.
i = np.argmin(np.abs(alpha-0.25))
sys_apr_quater = approx.get_approxiamtion(eps[i])
sys_apr_quater_per = approx_per.get_approxiamtion(eps[i])
print(eps[i])

# Index 0: causal part, index 1: anticausal part.
dims_state_ref = [sys_apr_quater.causal_system.dims_state,sys_apr_quater.anticausal_system.dims_state]
dims_state = [sys_apr_quater_per.causal_system.dims_state,sys_apr_quater_per.anticausal_system.dims_state]

x = np.arange(len(dims_state[0]))  # the label locations
width = 0.35  # the width of the bars

w = setup_plots.textwidth
fig, axes = plt.subplots(2,1,figsize=(w, w/2),sharex=True)
for v in [0,1]:
    ax =axes[v]
    # Each series gets its legend label in only one of the two panels so the
    # figure-level legend lists it once.
    if v ==0:
        label = ['Regular',None]
        ax.set_ylabel(r"$d$")
    else:
        label = [None,'Permuted']
        ax.set_ylabel(r"$d^*$")
    rects1 = ax.bar(x - width/2, dims_state_ref[v], width, label=label[0])
    rects2 = ax.bar(x + width/2, dims_state[v], width, label=label[1])
    ax.grid()
    ax.set_ylim(0,105)
    # Zero-height white bars only carry the totals into the legend.
    if v ==0:
        ax.bar(0,0,label=r"$\Sigma ="+str(np.sum(dims_state_ref))+"$",color = "w")
        ax.bar(0,0,label=" ",color = "w")
    else:
        ax.bar(0,0,label=r"$\Sigma ="+str(np.sum(dims_state))+"$",color = "w")#linewidth=0

# NOTE(review): fixed x-range; presumably tuned to the stage count of this
# example -- confirm if the number of stages changes.
ax.set_xlim(0.4,7.6)
ax.set_xlabel(r"$k$")

fig.legend(loc='center right')
plt.subplots_adjust(right=0.75)

# `bbox` is not a savefig keyword (recent Matplotlib raises TypeError for it);
# `bbox_inches` alone requests the tight bounding box.
plt.savefig("perm_example_mobilenet_state_dims.pdf",bbox_inches = 'tight')
bbox = plt.gcf().get_tightbbox( plt.gcf().canvas.get_renderer())
print(bbox.width/setup_plots.textwidth)

In [None]:
# Add causal and anticausal state dimensions per index (axis 0 runs over the
# [causal, anticausal] pair), then print each vector with its total.
dim_sum = np.sum(dims_state,axis=0)
dim_sum_ref = np.sum(dims_state_ref,axis=0)
print(dim_sum,np.sum(dim_sum))
print(dim_sum_ref,np.sum(dim_sum_ref))

In [None]:
# One stacked panel per report, showing its "f" entry.
# The loop counter is deliberately not called `i`, so the notebook-level
# index `i` used by other cells is left alone.
plt.figure(figsize = (4,2*len(reports)))
for row,report in enumerate(reports, start=1):
    plt.subplot(len(reports),1,row)
    plt.plot(report["f"])


## Average values of Blocks

In [None]:
# Element-wise square of the permuted matrix (a fresh array, T_per is kept).
M = T_per**2

# Block boundaries of the permuted system, with a leading 0.
dims_in_cum = np.hstack(([0],np.cumsum(sys_per.dims_in)))
dims_out_cum = np.hstack(([0],np.cumsum(sys_per.dims_out)))
K = len(sys_per.dims_in)
for k in range(K):
    rows = slice(dims_out_cum[k], dims_out_cum[k+1])
    for l in range(K):
        # Basic slicing yields a view, so writing to it updates M in place.
        block = M[rows, dims_in_cum[l]:dims_in_cum[l+1]]
        v = np.sum(block)/block.size
        block[...] = v

plt.matshow(M)
plt.colorbar()

In [None]:
# Mean of the squared entries of the whole matrix T.
np.sum(T**2)/T.size

In [None]:
def _offdiag_frobenius_norms(matrix, dims_in, dims_out):
    """Frobenius norms of the lower-left and upper-right blocks at each cut.

    Args:
        matrix: 2-D array whose rows follow ``dims_out`` and columns ``dims_in``.
        dims_in: block widths (column direction).
        dims_out: block heights (row direction).

    Returns:
        tuple: ``(causal, anticausal)`` lists with one norm per interior cut.
        ``causal`` uses the block below and to the left of the cut,
        ``anticausal`` the block above and to the right of it.
    """
    row_cuts = np.cumsum(dims_out)
    col_cuts = np.cumsum(dims_in)
    n_stages = len(dims_in)
    causal = [np.linalg.norm(matrix[row_cuts[k-1]:, :col_cuts[k-1]])
              for k in range(1, n_stages)]
    anticausal = [np.linalg.norm(matrix[:row_cuts[k], col_cuts[k]:])
                  for k in range(n_stages-1)]
    return causal, anticausal


# Each system is evaluated with its own dimensions; the stage count for the
# permuted matrix now comes from sys_per rather than from sys.
frob_causal, frob_anticausal = _offdiag_frobenius_norms(T, sys.dims_in, sys.dims_out)
frob_causal_per, frob_anticausal_per = _offdiag_frobenius_norms(
    T_per, sys_per.dims_in, sys_per.dims_out)


In [None]:
# Top panel: causal norms, bottom panel: anticausal norms; in each panel the
# reference curve is drawn first, the permuted one second.
for position, (reference, permuted) in enumerate(
        ((frob_causal, frob_causal_per), (frob_anticausal, frob_anticausal_per)),
        start=1):
    plt.subplot(2,1,position)
    plt.plot(reference)
    plt.plot(permuted)

# AlexNet

In [None]:
def get_AlexNet_target_mats():
    """Collect the weight matrices of the linear layers in AlexNet's classifier.

    Loads torchvision's ``alexnet`` with ``pretrained=True``, switches it to
    eval mode and walks ``model.classifier``.

    Returns:
        list: one NumPy array per ``torch.nn.Linear`` layer, in classifier
        order, holding that layer's detached ``weight``.
    """
    net = models.alexnet(pretrained=True)
    # Put the model into eval mode before reading the weights.
    net.eval()
    return [
        module.weight.detach().numpy()
        for module in net.classifier
        if isinstance(module, torch.nn.Linear)
    ]
# First linear-layer weight matrix of the AlexNet classifier.
mat_AlexNet = get_AlexNet_target_mats()[0]

In [None]:
# From here on T refers to the AlexNet matrix (replaces the MobileNet one).
T = mat_AlexNet

In [None]:
# Reference system for the AlexNet matrix (second argument 4, epsilon=1e-3;
# presumably split levels and a truncation tolerance -- confirm against Split).
sys = Split.identification_split_system(T,4,epsilon=1e-3)
utils.check_dims(sys)

In [None]:
# Permuting identification of the AlexNet matrix ("fro" strategy).
# `opts` holds four separate dicts with identical settings.
sys_per,Ps_col,Ps_row,reports = SplitP.identification_split_permute(
    T,4,epsilon=1e-3,strategy="fro",
    opts=[{"gamma":6e3,"N":100} for _ in range(4)])
utils.check_dims(sys_per)

In [None]:
# Display the permuted AlexNet system.
utils.show_system(sys_per)

In [None]:
# Per-stage `s_in` vectors (presumably singular values -- confirm in tvsclib).
# The first causal and the last anticausal stage are left out.
sigmas_causal = [stage.s_in for stage in sys.causal_system.stages][1:]
sigmas_anticausal = [stage.s_in for stage in sys.anticausal_system.stages][:-1]

sigmas_causal_per = [stage.s_in for stage in sys_per.causal_system.stages][1:]
sigmas_anticausal_per = [stage.s_in for stage in sys_per.anticausal_system.stages][:-1]

plt.figure(figsize=[12,8])

# Left panel: causal part, right panel: anticausal part.
# Reference system in 'C0', permuted system in 'C1'.
for position, reference, permuted in ((1, sigmas_causal, sigmas_causal_per),
                                      (2, sigmas_anticausal, sigmas_anticausal_per)):
    plt.subplot(1,2,position)
    for color, sigmas in (('C0', reference), ('C1', permuted)):
        for sig in sigmas:
            plt.plot(np.arange(len(sig)),sig,color=color)
    plt.grid()


In [None]:
# Reorder T with the last row and column permutations, then show the largest
# absolute difference to the matrix rebuilt from the permuted system.
T_per = T[np.ix_(Ps_row[-1], Ps_col[-1])]
np.abs(T_per-sys_per.to_matrix()).max()

In [None]:
# Largest entry over all sigma vectors of the reference (unpermuted) system.
eps_max = max(np.max(sig) for sig in sigmas_causal + sigmas_anticausal)
print(eps_max)

In [None]:
# Sweep of thresholds: N evenly spaced fractions of eps_max.
N = 9 #number of points
#N = 18 #number of points
alpha = np.linspace(0,1,N)
eps = eps_max*alpha

err_move = np.zeros_like(alpha)

# Approximation objects for the reference and the permuted system.
approx = Approximation(sys,(sigmas_causal,sigmas_anticausal))
approx_per = Approximation(sys_per,(sigmas_causal_per,sigmas_anticausal_per))

def calc_values(approx,eps,matrix):
    """Evaluate approximation error and cost for every threshold in ``eps``.

    Args:
        approx: object providing ``get_approxiamtion(threshold)`` (spelling as
            called here); the returned system must offer ``to_matrix()`` and
            ``cost()``.
        eps: one-dimensional sequence of thresholds.
        matrix: reference matrix the approximations are compared against.

    Returns:
        tuple: ``(err, costs)``, two float arrays of length ``len(eps)``.
        ``err[i]`` is the 2-norm (``ord=2``) of the difference between the
        approximated matrix and ``matrix``; ``costs[i]`` is the value returned
        by ``cost()`` of the approximated system.
    """
    n_points = len(eps)
    # Allocate float buffers explicitly: zeros_like(eps) would inherit an
    # integer dtype from integer thresholds and truncate the results.
    err = np.zeros(n_points, dtype=float)
    costs = np.zeros(n_points, dtype=float)
    for i, threshold in enumerate(eps):
        approx_system = approx.get_approxiamtion(threshold)
        err[i] = np.linalg.norm(approx_system.to_matrix() - matrix, ord=2)
        costs[i] = approx_system.cost()
    return err,costs

# Error/cost curves: reference system against T, permuted system against the
# permuted matrix T_per.
err_orig,cost_orig = calc_values(approx,eps,T)
err_per,cost_per = calc_values(approx_per,eps,T_per)

In [None]:
# Error over cost for the reference ("orig") and the permuted ("per") system.
for costs, errors, name in ((cost_orig, err_orig, "orig"),
                            (cost_per, err_per, "per")):
    plt.plot(costs,errors,label=name)
plt.legend()
plt.grid()

In [None]:
# Summary of the sweep point with index 2.
i = 2
for caption, number in (
        ("alpha=", alpha[i]),
        ("eps=", eps[i]),
        ("Cost original=", cost_orig[i]),
        ("Cost new=", cost_per[i]),
        ("Cost new/Cost orig=", cost_per[i]/cost_orig[i]),
):
    print(caption, number)

In [None]:
# Threshold at sweep index 2.
print(eps[2])

In [None]:
# Cost of the reference approximation at sweep index 2.
cost_orig[2]

In [None]:
# Cost of the permuted approximation at sweep index 2.
cost_per[2]

In [None]:
i = 2
print(alpha[i])
# Relative cost saving of the permuted system; scaled by 100 so the number
# actually matches the "%" unit printed next to it.
print(100*(1-cost_per[i]/cost_orig[i]),"%")

In [None]:
# Cost ratio permuted/reference over the threshold fraction alpha.
plt.plot(alpha,cost_per/cost_orig)
plt.grid()

In [None]:
# Threshold at sweep index 2.
eps[2]

In [None]:
# One stacked panel per report, showing its "f" entry.
# The loop counter is deliberately not called `i`, so the notebook-level
# index `i` used by other cells is left alone.
plt.figure(figsize = (4,2*len(reports)))
for row,report in enumerate(reports, start=1):
    plt.subplot(len(reports),1,row)
    plt.plot(report["f"])


In [None]:
# Input dimensions: reference system first, permuted system second.
print(sys.dims_in)
print(sys_per.dims_in)

In [None]:
# Output dimensions: reference system first, permuted system second.
print(sys.dims_out)
print(sys_per.dims_out)

In [None]:
# Scratch check: the built-in sum adds up the rows of the 5x5 identity,
# giving a length-5 vector.
sum(np.eye(5))

In [None]:
# Scratch check: wrapping that vector in two lists gives shape (1, 1, 5).
np.array([[sum(np.eye(5))]]).shape

In [None]:
# Number of entries in the reference system's input dimensions.
len(sys.dims_in)

In [None]:
import os

# np.savez does not create missing directories (utils.save_system presumably
# behaves the same -- confirm), so make sure the output folder exists first.
os.makedirs('AlexNet', exist_ok=True)

# Store both systems together with their sigma vectors.
utils.save_system(sys,'AlexNet/system_ref_perm.npz',sigmas=(sigmas_causal,sigmas_anticausal))
utils.save_system(sys_per,'AlexNet/system_perm.npz',sigmas=(sigmas_causal_per,sigmas_anticausal_per))

# Store the error/cost curves and the permutations. When loading, the
# Approximation objects can be rebuilt from the saved systems and sigmas:
#approx =Approximation(sys,(sigmas_causal,sigmas_anticausal))
#approx_per=Approximation(sys_per,(sigmas_causal_per,sigmas_anticausal_per))
np.savez('AlexNet/data_per.npz',err_orig=err_orig,cost_orig=cost_orig,err_per=err_per,\
            cost_per=cost_per,\
            Ps_col=Ps_col,Ps_row=Ps_row)

# The saved curves were produced by:
#err_orig,cost_orig = calc_values(approx,eps,T)
#err_per,cost_per = calc_values(approx_per,eps,T_per)

In [None]:
# NOTE(review): `P` is not assigned anywhere in the visible part of this
# notebook; evaluating this cell presumably raises NameError unless it is
# defined elsewhere -- confirm or drop the cell.
P