In [None]:
from tvsclib.strict_system import StrictSystem
from tvsclib.stage import Stage
from tvsclib.system_identification_svd import SystemIdentificationSVD
from tvsclib.toeplitz_operator import ToeplitzOperator
from tvsclib.mixed_system import MixedSystem
import numpy as np
import scipy.linalg as linalg
import matplotlib.pyplot as plt
import scipy.linalg 
import scipy.stats 
import tvsclib.utils as utils
import tvsclib.math as math

import setup_plots
import move

In [None]:
# Apply the project-wide plot style first, then raise the figure resolution.
setup_plots.setup()
plt.rcParams.update({'figure.dpi': 150})

In [None]:
SEED = 14000
np.random.seed(SEED)  # seeds numpy's global legacy RNG (scipy.stats samplers default to it)

In [None]:
# Stage sizes of the partition used to build the test matrix (dims_out is dims_in reversed).
dims_in = np.array([6, 3, 5, 2]) * 3
dims_out = np.array([2, 5, 3, 6]) * 3

n = 1  # lower, diag and upper below are each a rank-n product
# Per stage: draw a random orthogonal matrix, keep 3*n of its columns (rows for Vts)
# and scale by the stage size, so that norm(block)/size(block) stays constant.
Us = np.vstack([scipy.stats.ortho_group.rvs(d_out)[:, :3*n] * d_out for d_out in dims_out])
Vts = np.hstack([scipy.stats.ortho_group.rvs(d_in)[:3*n, :] * d_in for d_in in dims_in])

lower, diag, upper = (Us[:, k*n:(k+1)*n] @ Vts[k*n:(k+1)*n, :] for k in range(3))

# Stage index of every row / column of the full matrix.
row_stage = np.repeat(np.arange(len(dims_out)), dims_out)[:, np.newaxis]
col_stage = np.repeat(np.arange(len(dims_in)), dims_in)[np.newaxis, :]
# Blocks strictly below the block diagonal come from `lower`, blocks on it from
# `diag`, blocks above it from `upper`.
matrix = np.where(row_stage > col_stage, lower,
                  np.where(row_stage == col_stage, diag, upper))


In [None]:
def split_evenly(total, n_parts):
    """Split `total` into `n_parts` integer sizes that differ by at most one.

    The remainder of total / n_parts is given to the first parts, so the sizes
    always sum to `total` (repeating `total // n_parts` would silently drop it,
    and the stage sizes would no longer match the matrix size).
    """
    base, remainder = divmod(int(total), n_parts)
    return [base + 1 if k < remainder else base for k in range(n_parts)]


# Identify the system on equally sized stages (same number of stages as the
# partition used to build `matrix`); the next cell runs move.move on it and
# prints the result next to dims_in / dims_out.
n_stages = len(dims_in)
dims_in_start = split_evenly(dims_in.sum(), n_stages)
dims_out_start = split_evenly(dims_out.sum(), n_stages)
T = ToeplitzOperator(matrix, dims_in_start, dims_out_start)
S = SystemIdentificationSVD(T, epsilon=1e-12)
system = MixedSystem(S)


In [None]:
def cost_nuc(s, s_a):
    """Cost function passed to `move.move` below.

    Returns sum(s)/max(s) + sum(s_a)/max(s_a): each argument is summed relative
    to its largest entry. For singular values this is the nuclear norm divided
    by the spectral norm, which is invariant to scaling and grows with the
    number of significant values.
    NOTE(review): presumably `s` and `s_a` are singular-value vectors supplied
    by move.move; confirm against move.py.

    Parameters
    ----------
    s, s_a : array_like
        Non-empty 1-D sequences of non-negative numbers.

    Returns
    -------
    float
        The combined normalized sum.

    Raises
    ------
    ValueError
        If either argument is empty (raised by np.max).
    """
    def _normalized_sum(values):
        # Sum relative to the largest entry; unchanged when values are scaled by c > 0.
        return np.sum(values) / np.max(values)

    return _normalized_sum(s) + _normalized_sum(s_a)

# Per-iteration move sizes for the input and output stage boundaries
# (presumably the maximum shift per iteration, cf. the search-tree inset later; confirm in move.py).
m_in = [4, 2, 1, 1]
m_out = list(m_in)  # separate list object with the same values

sys_move, input_dims, output_dims, fs = move.move(
    system, None, cost_nuc, m_in=m_in, m_out=m_out)

# Partition used to build the matrix vs. the partition found by move.move.
for dims in (dims_in, np.array(sys_move.dims_in), dims_out, np.array(sys_move.dims_out)):
    print(dims)

In [None]:
w = setup_plots.textwidth
fig, axes = plt.subplots(ncols=2, figsize=[w, w/2])
# Left: `system` (identified on equally sized stages); right: `sys_move`.
utils.show_system(system, ax=axes[0])
utils.show_system(sys_move, ax=axes[1])

# Annotate both panels with the partition used to build the matrix:
# unlabelled major ticks mark the block boundaries, and minor ticks carry the
# block sizes as labels near the block centres.
in_edges = np.cumsum(dims_in[:-1]) - 0.5
out_edges = np.cumsum(dims_out[:-1]) - 0.5
in_centres = np.cumsum(dims_in) - 0.5*dims_in
# NOTE(review): the extra -1 on the row centres looks like a manual
# label-placement tweak; kept as-is.
out_centres = np.cumsum(dims_out) - 0.5*dims_out - 1
for ax in axes:
    ax.set_xticks(in_edges)
    ax.set_yticks(out_edges)
    # One blank label per major tick; each count is taken from its own axis.
    ax.set_xticklabels([' ']*len(in_edges))
    ax.set_yticklabels([' ']*len(out_edges))
    ax.set_xticks(in_centres, minor=True)
    ax.set_xticklabels(["$"+str(d)+"$" for d in dims_in], minor=True)
    ax.set_yticks(out_centres, minor=True)
    ax.set_yticklabels(["$"+str(d)+"$" for d in dims_out], minor=True)

    ax.tick_params(which='major', length=7)
    # Minor ticks serve only as label anchors: make them tiny and white.
    ax.tick_params(which='minor', length=1, color='w')
    ax.xaxis.set_ticks_position('top')
    ax.yaxis.set_ticks_position('left')

# `bbox` is not a savefig parameter (newer matplotlib raises TypeError for it);
# `bbox_inches` is what crops the saved figure.
fig.savefig("example_move.pdf", bbox_inches="tight")

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable


def _step_coords(n_steps):
    """Staircase coordinates for drawing one value per step over [0, n_steps].

    Step k is held from k to k+1, so the base coordinates are
    [0, 1, 1, 2, 2, ..., n_steps] (length 2*n_steps). Small alternating
    +/-0.1 offsets are added so that coinciding paths remain distinguishable.
    """
    staircase = np.repeat(np.arange(n_steps + 1), 2)[1:-1]
    return staircase + np.tile([0.1, -0.1], n_steps)


def _plot_boundary_paths(target_ax, dims_cum, transpose=False):
    """Plot how each inner stage boundary moves over the iterations.

    `dims_cum` is a cumulative sum of stage sizes along axis 0: row i is the
    boundary after stage i, and columns are iterations (the axis labelled
    'Iteration' in this figure). The last row is the total size, not a
    boundary, so it is skipped. With transpose=True the iterations run along y.
    """
    steps = _step_coords(dims_cum.shape[1])
    for boundary in dims_cum[:-1]:
        # -0.5: presumably aligns with the pixel edges of the image drawn by show_system.
        positions = np.repeat(boundary, 2) - 0.5
        if transpose:
            target_ax.plot(positions, steps, linestyle='solid', color='C0')
        else:
            target_ax.plot(steps, positions, linestyle='solid', color='C0')


w = setup_plots.textwidth
fig, ax = plt.subplots(figsize=(0.62*w, .62*w))

# Main panel: the moved system. Its limits are stored here and restored at the
# end of the cell, after the shared side panels have been drawn.
utils.show_system(sys_move, ax=ax)
y_lim = ax.get_ylim()
x_lim = ax.get_xlim()
ax.xaxis.set_ticks_position('top')

# Side panels that share the column axis (top) and the row axis (left) with the main panel.
divider = make_axes_locatable(ax)
ax_dimsin = divider.append_axes("top", 0.68, pad=0.1, sharex=ax)
ax_dimsout = divider.append_axes("left", 0.68, pad=0.1, sharey=ax)
ax_dimsin.xaxis.set_tick_params(labelbottom=False)
ax_dimsin.invert_yaxis()

din_cum = np.cumsum(input_dims, axis=0)
dout_cum = np.cumsum(output_dims, axis=0)
n_iter_in = din_cum.shape[1]
n_iter_out = dout_cum.shape[1]

_plot_boundary_paths(ax_dimsout, dout_cum)
_plot_boundary_paths(ax_dimsin, din_cum, transpose=True)

ax_dimsout.xaxis.set_ticks_position('top')
ax_dimsout.yaxis.set_ticks_position('right')
# set_ticks_position('right') re-enables the right-hand labels; hide them again.
ax_dimsout.yaxis.set_tick_params(labelright=False)

grid_step = 3  # grid spacing; same factor 3 that dims_in / dims_out were built with
ax_dimsin.set_yticks(np.arange(1, n_iter_in))
ax_dimsout.set_xticks(np.arange(1, n_iter_out))
ax_dimsin.set_xticks(np.arange(grid_step, dims_in.sum(), grid_step) - 0.5)
ax_dimsout.set_yticks(np.arange(grid_step, dims_out.sum(), grid_step) - 0.5)

ax_dimsin.grid()
ax_dimsout.grid()
# In both side panels the last iteration is the edge adjacent to the main panel.
ax_dimsout.set_xlim((0, n_iter_out))
ax_dimsin.set_ylim((n_iter_in, 0))
ax.set_ylim(y_lim)
ax.set_xlim(x_lim)

fig.text(0.3, 0.72, 'Iteration', rotation=-45,
         horizontalalignment='right', verticalalignment='center', rotation_mode='anchor')

# Inset on the output panel: the tree of all move sequences explored for one boundary.
inset_scale = 2
axins = ax_dimsout.inset_axes([-2.75, 0.25, inset_scale*4.8/5, inset_scale*20/48])
zoom_x = (0.2, 4.8)
zoom_y = (34 - 0.5, 14 - 0.5)  # bottom > top: rows grow downwards like the main panel

# At iteration i a boundary shifts by +m, 0 or -m with m = m_out[i]. Each column
# is one sequence of shifts, and row r accumulates the shifts of iterations < r.
# (The output moves are used here because this inset shows output boundaries;
# m_out equals m_in in this notebook.)
n_moves = len(m_out)
tree_offsets = np.zeros((n_moves + 1, 3**n_moves))
for i in range(n_moves):
    shifts = np.array([m_out[i], 0, -m_out[i]] * 3**(n_moves - i - 1))
    tree_offsets[i+1:] += np.repeat(shifts, 3**i).reshape(1, -1)

tree_x = _step_coords(tree_offsets.shape[0])
# NOTE(review): y = 24 is presumably the initial middle output boundary (12 + 12); confirm.
tree_root = 24
axins.plot(tree_x, tree_root + np.repeat(tree_offsets, 2, axis=0) - 0.5,
           linestyle=':', color='0.3', linewidth=1)
# Solid: the same paths with positive offsets capped at 6
# (NOTE(review): presumably the feasible range for this boundary; confirm).
axins.plot(tree_x, tree_root + np.repeat(np.clip(tree_offsets, None, 6), 2, axis=0) - 0.5,
           linestyle='solid', color='0.3', linewidth=1)

_plot_boundary_paths(axins, dout_cum)

axins.set_xticks(np.arange(1, n_iter_out))
axins.set_yticks(np.arange(grid_step, dims_out.sum(), grid_step) - 0.5)
axins.grid()
axins.set_xlim(*zoom_x)
axins.set_ylim(*zoom_y)
axins.set_yticklabels([])
axins.set_xlabel("Iteration")

ax_dimsout.indicate_inset_zoom(axins, edgecolor="black")
# `bbox` is not a savefig parameter (newer matplotlib raises TypeError for it);
# `bbox_inches` does the cropping.
fig.savefig("example_move_iterations.pdf", bbox_inches='tight')
bbox = fig.get_tightbbox(fig.canvas.get_renderer())


In [None]:
# get the fraction of the textwidt for tex
# Width of the saved iteration figure as a fraction of the LaTeX text width
# (bbox is in inches; assumes setup_plots.textwidth is in inches too; confirm).
textwidth_fraction = bbox.width / w
textwidth_fraction

In [None]:
# Input stage boundaries per iteration: cumulative stage sizes along axis 0.
np.asarray(input_dims).cumsum(axis=0)

In [None]:
# Plot `fs` as returned by move.move (presumably the cost per iteration; confirm in move.py).
plt.gca().plot(fs)