In [None]:
from tvsclib.strict_system import StrictSystem
from tvsclib.stage import Stage
from tvsclib.system_identification_svd import SystemIdentificationSVD
from tvsclib.toeplitz_operator import ToeplitzOperator
from tvsclib.mixed_system import MixedSystem
import numpy as np
import scipy.linalg as linalg
import matplotlib.pyplot as plt
import scipy.linalg 
import scipy.stats 
import tvsclib.utils as utils
import tvsclib.math as math

import setup_plots
import move

import torchvision.models as models
import torch

In [None]:
# Apply the project-wide plot style first, then raise the on-screen resolution.
setup_plots.setup()
plt.rcParams.update({'figure.dpi': 150})

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable

def plot_moves(sys_move,input_dims,output_dims,fs,text_ylabel=" ",spacing=2):
    """Plot a system together with how its stage boundaries moved per iteration.

    The figure has three parts:

    * the main axes, drawn by ``utils.show_system(sys_move, ax=ax)``;
    * two marginal axes attached with ``make_axes_locatable`` (top: input
      dims, left: output dims), showing the cumulative stage boundaries
      for every iteration;
    * a small side axes plotting ``fs`` against the iteration index.

    Args:
        sys_move: system passed to ``utils.show_system``.
        input_dims: 2-D array of input dims, cumulated along axis 0; column
            ``j`` is drawn at iteration ``j``. Presumably shaped
            (stages, iterations) -- TODO confirm against ``move.move``.
        output_dims: 2-D array of output dims with the same layout as
            ``input_dims``.
        fs: sequence of values plotted per iteration in the side axes.
        text_ylabel: y-axis label of the side axes.
        spacing: step between iteration tick marks on the marginal and side
            axes (default 2, the value previously hard-coded).

    Returns:
        The created ``matplotlib.figure.Figure``.
    """
    w = setup_plots.textwidth
    fig = plt.figure(figsize=(1*w, .5*w))
    ax = fig.add_axes([0.1,0.2,0.65,0.9])    # [left, bottom, width, height]
    axf = fig.add_axes([0.85,0.35,0.3,0.45]) # [left, bottom, width, height]

    utils.show_system(sys_move,ax=ax)
    # Remember the limits chosen by show_system; the marginal axes share
    # x/y with `ax`, so the limits are restored at the end.
    y_lim = ax.get_ylim()
    x_lim = ax.get_xlim()
    ax.xaxis.set_ticks_position('top')

    divider = make_axes_locatable(ax)
    ax_dimsin = divider.append_axes("top", 0.68, pad=0.1, sharex=ax)
    ax_dimsout = divider.append_axes("left", 0.68, pad=0.1, sharey=ax)

    ax_dimsin.xaxis.set_tick_params(labelbottom=False)
    ax_dimsin.invert_yaxis()

    N = input_dims.shape[1]

    # Alternating offsets make the step lines slightly slanted.
    angl = np.array([0.1,-0.1]*N)

    din_cum = np.cumsum(input_dims,axis=0)
    dout_cum = np.cumsum(output_dims,axis=0)
    # The last cumulative row is the total dimension, so it is not drawn.
    for i in range(dout_cum.shape[0]-1):
        ax_dimsout.plot(np.repeat(np.arange(dout_cum.shape[1]+1),2)[1:-1]+angl,
                        np.repeat(dout_cum[i,:],2)-0.5,
                        linestyle='solid',color='C0')

    for i in range(din_cum.shape[0]-1):
        ax_dimsin.plot(np.repeat(din_cum[i,:],2)-0.5,
                       np.repeat(np.arange(din_cum.shape[1]+1),2)[1:-1]+angl,
                       linestyle='solid',color='C0')

    ax_dimsout.xaxis.set_ticks_position('top')
    ax_dimsout.yaxis.set_ticks_position('right')
    # set_ticks_position('right') turns the right labels on; hide them again.
    ax_dimsout.yaxis.set_tick_params(labelright=False)

    iteration_ticks = np.arange(1,N,spacing)
    ax_dimsin.set_yticks(iteration_ticks)
    ax_dimsout.set_xticks(iteration_ticks)

    ax_dimsin.grid()
    ax_dimsout.grid()
    ax_dimsout.set_xlim((0,N))
    ax_dimsin.set_ylim((N,0))
    ax.set_ylim(y_lim)
    ax.set_xlim(x_lim)

    # Diagonal "Iteration" label in the corner between the two marginal axes.
    offset = 0.1
    pos_x = ax_dimsout.get_position().xmin+0.075+offset/fig.get_figwidth()
    pos_y = ax_dimsin.get_position().ymax-offset/fig.get_figheight()
    fig.text(pos_x,pos_y,'Iteration',rotation=-45,
             horizontalalignment='left', verticalalignment='center',rotation_mode='anchor')

    axf.plot(fs)
    axf.grid()
    axf.set_xlabel('Iteration')
    axf.set_ylabel(text_ylabel)
    axf.ticklabel_format(axis='y',scilimits=(0,0))
    axf.set_xticks(iteration_ticks)
    axf.set_xlim((0,N-1))
    return fig
    

In [None]:
def get_mobilenet_target_mats():
    """Return the weight matrices of the Linear layers in MobileNetV2's classifier.

    Loads torchvision's ImageNet-pretrained ``mobilenet_v2``, which may
    download the weights on first use. It then collects ``layer.weight`` of
    every ``torch.nn.Linear`` in ``model.classifier``, in order, as NumPy
    arrays.

    Returns:
        list of np.ndarray, each of shape (out_features, in_features).
    """
    # torchvision >= 0.13 replaced `pretrained=True` with the `weights` enum
    # (the old flag is deprecated). IMAGENET1K_V1 is what `pretrained=True`
    # maps to. Fall back to the old flag on older torchvision versions.
    weights_enum = getattr(models, "MobileNet_V2_Weights", None)
    if weights_enum is not None:
        model = models.mobilenet_v2(weights=weights_enum.IMAGENET1K_V1)
    else:
        model = models.mobilenet_v2(pretrained=True)
    # Put the model into eval mode
    model.eval()
    target_mats = [
        layer.weight.detach().numpy()
        for layer in model.classifier
        if isinstance(layer, torch.nn.Linear)
    ]
    return target_mats
mat_mobilenet = get_mobilenet_target_mats()[0]

In [None]:
np.shape(mat_mobilenet)  # (rows, cols) of the weight matrix used as the target

In [None]:
stages = 15

def split_dims_evenly(total, n_stages):
    """Split ``total`` into ``n_stages`` nearly equal integer sizes.

    The stage boundaries are ``round(total/n_stages * k)`` for
    ``k = 0..n_stages``, so the sizes sum to ``total``.

    Returns:
        np.ndarray of int with ``n_stages`` entries.
    """
    boundaries = np.round(total/n_stages*np.arange(n_stages+1)).astype(int)
    return np.diff(boundaries)

# Split the column (input) and row (output) dimension of the matrix into
# `stages` nearly equal parts.
d_in = mat_mobilenet.shape[1]
dims_in = split_dims_evenly(d_in, stages)

d_out = mat_mobilenet.shape[0]
dims_out = split_dims_evenly(d_out, stages)

assert sum(dims_in)==d_in and sum(dims_out)==d_out

T = ToeplitzOperator(mat_mobilenet, dims_in,dims_out)
# Initial realization. `epsilon` is presumably the SVD truncation tolerance
# of the identification -- TODO confirm in tvsclib.
S = SystemIdentificationSVD(T,epsilon=2e-1)
system = MixedSystem(S)
print(system)

In [None]:
def cost_computation(sigmas_causal,sigmas_anticausal,dims_in,dims_out,tol=None):
    """Cost of a mixed system whose state dims come from thresholded singular values.

    Let ``k = len(dims_in)``. For each ``i`` in ``0..k-2``, the causal state
    dim at boundary ``i+1`` is the number of entries of ``sigmas_causal[i]``
    above the threshold; the anticausal dims are computed the same way. The
    outer state dims (index 0 and k) stay 0.

    The result is ``math.cost`` of the causal part plus ``math.cost`` of the
    anticausal part. The anticausal part is evaluated with
    ``include_D=False``, presumably so that D is not counted twice.

    Args:
        sigmas_causal: sequence of (at least ``k-1``) arrays of singular values.
        sigmas_anticausal: same for the anticausal part.
        dims_in: input dims per stage.
        dims_out: output dims per stage.
        tol: singular-value threshold. ``None`` (default) keeps the previous
            behavior of reading the notebook-global ``eps`` at call time.

    Returns:
        Sum of the two ``math.cost`` values.
    """
    # Explicit threshold instead of a silent dependency on a global defined
    # in a later cell; the global is still used when no tol is given.
    threshold = eps if tol is None else tol
    k = len(dims_in)
    dims_state_causal = np.zeros(k+1)
    dims_state_anticausal = np.zeros(k+1)
    # state dimension at each internal boundary = number of kept singular values
    for i in range(k-1):
        dims_state_causal[i+1] = np.count_nonzero(sigmas_causal[i]>threshold)
        dims_state_anticausal[i+1] = np.count_nonzero(sigmas_anticausal[i]>threshold)

    return math.cost(dims_in,dims_out,dims_state_causal,causal=True)\
            +math.cost(dims_in,dims_out,dims_state_anticausal,causal=False,include_D=False)

In [None]:
# Singular-value threshold used by `cost_computation`: a fraction of
# `math.hankelnorm` of the matrix for the current stage dims.
EPS_FRACTION = 0.35
eps_max = math.hankelnorm(mat_mobilenet,system.dims_in,system.dims_out)
eps = eps_max*EPS_FRACTION
print("eps:",eps)

# Per-iteration parameters for move.move: start at MOVE_START and shrink
# geometrically by MOVE_DECAY over MOVE_ITERATIONS iterations (rounded up).
# Presumably the maximal boundary shift per iteration -- TODO confirm in `move`.
MOVE_START = 30
MOVE_DECAY = 1.5
MOVE_ITERATIONS = 10
m_in = np.ceil(MOVE_START/MOVE_DECAY**np.arange(MOVE_ITERATIONS)).astype(int)
m_out = m_in
sys_move,input_dims,output_dims,fs = move.move(system,None,cost_computation,m_in=m_in,m_out=m_out,cost_global=True)
print("l=")
display(m_out)

In [None]:
plot_moves(sys_move,input_dims,output_dims,fs,text_ylabel=r'$\text{f}_{\text{FLOP}}(\Sigma)$')
fig = plt.gcf()

# `bbox` is not a savefig parameter (deprecated in Matplotlib 3.3 and rejected
# by newer versions); `bbox_inches='tight'` is what crops the saved figure.
fig.savefig("move_example_mobilenet_comp.pdf",bbox_inches='tight')
# width of the tight bounding box relative to the document text width
bbox = fig.get_tightbbox(fig.canvas.get_renderer())
print(bbox.width/setup_plots.textwidth)