# Clustering columns and rows

The idea is to cluster columns and rows by finding suitable row and column permutations of the matrix.

Here we have the objective function 



In [None]:
import numpy as np
import tvsclib.utils as utils
import Split
import matplotlib.pyplot as plt
from tvsclib.strict_system import StrictSystem

from tvsclib.approximation import Approximation

import torchvision.models as models
import torch
import scipy.stats 

import scipy.linalg as linalg
import plot_permutations as perm

import setup_plots
import Split_permute as SplitP

In [None]:
# Apply the shared plotting style, then raise the figure resolution.
setup_plots.setup()
plt.rcParams.update({'figure.dpi': 150})

In [None]:
# Fix the global NumPy random state (used by np.random.permutation and, by
# default, by scipy.stats samplers) so the generated example is reproducible.
np.random.seed(seed=10000)

## Get test matrix

In [None]:
# Block sizes of the reference matrix: columns are grouped into blocks of
# sizes dims_in, rows into blocks of sizes dims_out.
dims_in = np.array([4, 4, 4, 4]) * 4
dims_out = np.array([4, 4, 4, 4]) * 4

# dims_in = np.array([4, 5, 5, 4]) * 3
# dims_out = np.array([5, 4, 4, 5]) * 3

n = 2  # rank of the lower, diagonal and upper generator matrices below
# For every block create 3*n orthonormal vectors and scale them by the block
# size, intended to keep norm(block)/size(block) constant across blocks.
# Us is generated completely before Vts, so the random stream is consumed in
# the same order as before.
Us = np.vstack(
    [scipy.stats.ortho_group.rvs(d_out)[:, :3 * n] * d_out for d_out in dims_out]
)
Vts = np.hstack(
    [scipy.stats.ortho_group.rvs(d_in)[:3 * n, :] * d_in for d_in in dims_in]
)

s = np.linspace(1, 0.75, n)

# Three independent rank-n matrices built from disjoint vector sets.
lower = Us[:, :n] @ np.diag(s) @ Vts[:n, :]
diag = Us[:, n:2 * n] @ np.diag(s) @ Vts[n:2 * n, :]
upper = Us[:, 2 * n:3 * n] @ np.diag(s) @ Vts[2 * n:3 * n, :]

# Assemble the test matrix block row by block row: everything left of the
# diagonal block comes from `lower`, the diagonal block from `diag`, and
# everything right of it from `upper`.
matrix = np.zeros_like(diag)
row = 0
col = 0
for d_out, d_in in zip(dims_out, dims_in):
    rows = slice(row, row + d_out)
    matrix[rows, :col] = lower[rows, :col]
    matrix[rows, col:col + d_in] = diag[rows, col:col + d_in]
    matrix[rows, col + d_in:] = upper[rows, col + d_in:]
    row += d_out
    col += d_in

# Randomly permute columns and rows to hide the block structure.
P_in_ref = np.random.permutation(np.arange(matrix.shape[1]))
P_out_ref = np.random.permutation(np.arange(matrix.shape[0]))

T = matrix[P_out_ref][:, P_in_ref]
# matshow opens its own figure, so no separate plt.figure() call is needed.
plt.matshow(T)
print(T.shape)

In [None]:
# Identify a system from T while permuting its columns and rows; Ps_col and
# Ps_row hold the permutations collected by the algorithm.
sys, Ps_col, Ps_row, reports = SplitP.identification_split_permute(
    T,
    2,
    strategy="rank",
    opts={"gamma": 1e5},
)
# The last permutations are the ones under which sys represents T.
P_col, P_row = Ps_col[-1], Ps_row[-1]
utils.check_dims(sys)
utils.show_system(sys)

In [None]:
# Show the final column permutation (left) and row permutation (right):
# position on the x axis, original index on the y axis.
for panel, p_vec in enumerate((P_col, P_row), start=1):
    plt.subplot(1, 2, panel)
    plt.scatter(np.arange(len(p_vec)), p_vec)

In [None]:
# Collect the s_in values of every stage; the first causal and the last
# anticausal stage are skipped (presumably they carry no state — verify).
sigmas_causal = [stage.s_in for stage in sys.causal_system.stages][1:]
sigmas_anticausal = [stage.s_in for stage in sys.anticausal_system.stages][:-1]

# Left panel: causal part, right panel: anticausal part; one series per stage.
for panel, sigma_list in ((1, sigmas_causal), (2, sigmas_anticausal)):
    plt.subplot(1, 2, panel)
    for k, sig in enumerate(sigma_list):
        plt.scatter(np.arange(len(sig)), sig, label=str(k))
    plt.legend()

In [None]:
utils.check_dims(sys)
# Largest entry-wise deviation between the permuted input matrix and the
# matrix represented by the identified system; small means a good fit.
T_permuted = T[P_row][:, P_col]
np.abs(T_permuted - sys.to_matrix()).max()

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Figure for the paper: the identified system in the centre, with the
# column permutations of each iteration on top and the row permutations on
# the left, coloured by the block each column/row belonged to in the
# reference matrix.
w = setup_plots.textwidth * 0.7
fig, ax = plt.subplots(figsize=(w, w))

utils.show_system(sys, ax=ax)
# Remember the limits set by show_system so they can be restored after the
# side axes (which share x/y with ax) have been configured.
y_lim = ax.get_ylim()
x_lim = ax.get_xlim()
ax.xaxis.set_ticks_position('top')

divider = make_axes_locatable(ax)
ax_dimsin = divider.append_axes("top", 0.9, pad=0.045, sharex=ax)
ax_dimsout = divider.append_axes("left", 0.9, pad=0.045, sharey=ax)

# Make some labels invisible.
ax_dimsin.xaxis.set_tick_params(labelbottom=False)

ax_dimsin.invert_yaxis()

# One colour per reference block, taken straight from the tab20 lookup table
# (integer indices).  plt.cm.get_cmap was removed in Matplotlib 3.9.
cmap = plt.get_cmap('tab20')
colors = np.repeat(cmap(np.arange(len(dims_in))), dims_in, axis=0)[P_in_ref]
perm.multiple_connection_plot(
    perm.invert_permutations(Ps_col), colors=colors, start=0, end=2,
    ax=ax_dimsin, flipxy=True, linewidth=2.0,
)

colors = np.repeat(cmap(np.arange(len(dims_out))), dims_out, axis=0)[P_out_ref]
perm.multiple_connection_plot(
    perm.invert_permutations(Ps_row), colors=colors, start=0, end=2,
    ax=ax_dimsout, linewidth=2.0,
)

ax_dimsout.xaxis.set_ticks_position('top')
ax_dimsout.yaxis.set_ticks_position('right')
# set_ticks_position('right') turns the right labels back on; hide them again.
ax_dimsout.yaxis.set_tick_params(labelright=False)

# Grid lines at the reference block boundaries (offset by half a cell so they
# fall between columns/rows), derived from the block sizes instead of being
# hard-coded for four blocks of 16.
ax_dimsin.set_xticks(np.concatenate(([0], np.cumsum(dims_in))) - 0.5)
ax_dimsout.set_yticks(np.concatenate(([0], np.cumsum(dims_out))) - 0.5)
ax.set_xticklabels([])
ax.set_yticklabels([])

ax_dimsin.grid()
ax_dimsout.grid()
# The side axes show iterations 0..2.
ax_dimsout.set_xlim((0, 2))
ax_dimsin.set_ylim((2, 0))
ax.set_ylim(y_lim)
ax.set_xlim(x_lim)

ax_dimsin.set_yticks(np.arange(0, 3))
ax_dimsout.set_xticks(np.arange(0, 3))
ax_dimsin.set_yticklabels(["$0$", "$1$", r""], zorder=0)
ax_dimsout.set_xticklabels(["$0$", "$1$", r"$2\,.$"], zorder=0)

plt.figtext(0.3, 0.72, 'Iteration', rotation=-45,
            horizontalalignment='right', verticalalignment='center',
            rotation_mode='anchor')

# Only bbox_inches is a savefig option; the extra `bbox` keyword is rejected
# by current Matplotlib releases.
plt.savefig("example_permute.pdf", bbox_inches='tight')
bbox = fig.get_tightbbox(fig.canvas.get_renderer())

In [None]:
# Ratio of the figure's tight bounding-box width to the requested width w.
width_ratio = bbox.width / w
width_ratio