In [None]:
# Essential
import numpy as np
import math
import random
import copy
import pickle

from matplotlib import pyplot as plt
import matplotlib as mpl
import matplotlib.ticker as ticker
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
from matplotlib.ticker import FormatStrFormatter
from matplotlib.lines import Line2D

# Pytorch
import torch
# Seed every RNG the notebook may touch so Restart & Run All is reproducible.
# torch drives the sampling below; numpy is seeded too because scipy samplers
# (e.g. scipy.stats.ortho_group) draw from numpy's global RNG.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
from torch.utils.data import TensorDataset, DataLoader
import torchvision.transforms as Transform
import torch.optim.lr_scheduler as lr_scheduler
import scipy
from scipy.stats import ortho_group

# Display support
import cv2
from fastprogress.fastprogress import master_bar, progress_bar

%matplotlib inline
# Plot styling. fonttype 42 embeds TrueType fonts in PDF/PS output.
mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42
mpl.rcParams['font.family'] = 'monospace' # windows
# mpl.rcParams['font.family'] = 'Helvetica' # macOS
mpl.rcParams['font.size'] = 12
mpl.rcParams['figure.dpi'] = 350
# text.usetex is not set here, so the matplotlib / matplotlibrc default applies.
# Per-series styles, repeated 10x so longer index ranges still get a style.
colors = ['C1', 'C6', 'C2', 'C9', 'C10', 'C3', 'C4', 'C7', 'C8', 'C5']*10
markers = [r'$\bigcirc$', r'$\boxdot$', r'$\bigtriangleup$', r'$\heartsuit$', r'$\diamondsuit$', 'v']*10

In [None]:
def dot(A, X):
    """Batched Frobenius inner product <A_i, X> for every measurement matrix A_i.

    Args:
        A: stack of measurement matrices, shape (m, n, n).
        X: matrix of shape (n, n).

    Returns:
        Tensor of shape (m,) with entries sum_{j,k} A[i, j, k] * X[j, k].
    """
    # Broadcasting X against A gives the same products as X.repeat(m, 1, 1)
    # without materializing m copies of X.
    return torch.sum(A * X, dim=(-2, -1))

def spectral_init(n, Phi, PsiT): # Assumption 2 in our paper
    """Spectrally aligned initialization of one factor pair (U, V).

    Returns U = Phi diag(u) G and V = PsiT^T diag(v) G, where u, v are sorted
    U(0, 1) vectors and G is a random orthogonal matrix shared by both. Since
    G G^T = I, the product U V^T = Phi diag(u * v) PsiT has Phi / PsiT as its
    singular bases.

    Args:
        n: matrix size.
        Phi: (n, n) left singular-vector matrix.
        PsiT: (n, n) transposed right singular-vector matrix.

    Returns:
        (U_init, V_init), both (n, n).
    """
    # G: left singular vectors of a uniform random matrix (orthogonal).
    # scipy.stats.ortho_group.rvs was the alternative, but it crashed the kernel.
    G, _, _ = torch.linalg.svd(torch.rand(n, n))
    u_diag, _ = torch.sort(torch.rand(n))
    v_diag, _ = torch.sort(torch.rand(n))
    U_init = Phi @ torch.diag(u_diag) @ G
    V_init = PsiT.T @ torch.diag(v_diag) @ G
    return U_init, V_init

def spectral_nonlinear(Q):
    """Apply tanh to the singular values of Q, keeping its singular vectors.

    With the thin SVD Q = U diag(s) V^T, returns U diag(tanh(s)) V^T. Q can be a
    single matrix (r, c) or a batch (..., r, c).
    """
    U, S, Vt = torch.linalg.svd(Q, full_matrices=False)
    S = torch.diag_embed(torch.tanh(S))
    # `@` broadcasts over any leading batch dims; torch.bmm only accepts 3-D input.
    return U @ S @ Vt

def mse_loss(X, A, y):
    """Half mean squared residual of the linear measurements: 0.5 * mean((<A_i, X> - y_i)^2)."""
    residual = dot(A, X) - y
    return 0.5 * residual.square().mean()

def test_error(X, Xtrue):
    return torch.norm(X - Xtrue, p='fro') / torch.norm(Xtrue, p='fro')

def erank(X):
    """Effective rank: exp of the Shannon entropy of the normalized singular values.

    With s the singular values of X and p = s / sum(s), returns exp(-sum p log p).
    Zero singular values contribute 0 (0 log 0 := 0) instead of producing NaN.
    For the all-zero matrix, p = 0/0 is undefined and the result is NaN.
    """
    S = torch.linalg.svdvals(X)
    p = S / torch.sum(S)
    # xlogy(p, p) equals p * log(p) but returns 0 where p == 0 (avoids 0 * -inf = NaN).
    return torch.exp(-torch.sum(torch.xlogy(p, p)))
def get_norm_grad(model):
    """Global L2 norm of the gradients of the trainable parameters of ``model``.

    Returns a scalar tensor sqrt(sum_p ||grad_p||_2^2) over the parameters with
    ``requires_grad=True``. Returns +inf when the model has no trainable
    parameters or when any of them has no gradient yet (``.grad is None``, e.g.
    before the first backward pass or after ``zero_grad(set_to_none=True)``).
    """
    trainable = [param for param in model.parameters() if param.requires_grad]
    if not trainable or any(param.grad is None for param in trainable):
        return torch.tensor(float('inf'))
    grad_norm = 0
    for param in trainable:
        grad_norm = grad_norm + torch.square(torch.norm(param.grad, p=2))
    return torch.sqrt(grad_norm)

def PSNR(original, compressed):
    """Peak signal-to-noise ratio in dB for 8-bit data (peak value 255).

    Returns 100 when the two inputs are identical (zero MSE), where PSNR is unbounded.
    """
    err = (original - compressed).square().mean().detach().numpy()
    if err == 0:
        return 100
    peak = 255.0
    return 20 * np.log10(peak / np.sqrt(err))
def count_parameters(model):
    """Number of scalar entries in the parameters of ``model`` that require gradients."""
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

In [None]:
def plot_matrix(inputmat, inputaxes = None, title = None):
    """Draw a tensor as an 8-bit image (values clamped to [0, 255]).

    2-D input is shown in grayscale. 3-D input is assumed to be an (H, W, 3) BGR
    image (OpenCV channel order) and is converted to RGB. When no axes are given,
    a small new figure is created. Plotting is best effort: errors are printed,
    not raised.
    """
    if inputaxes is None:
        _, axesout = plt.subplots(nrows=1, ncols=1,figsize=(2,1.5), linewidth=10, edgecolor="#04253a")
    else:
        axesout = inputaxes
    try:
        img = torch.clamp(inputmat, min=0, max=255).detach().numpy().astype('uint8')
        if img.ndim == 2:
            # cv2.COLOR_BGR2RGB rejects single-channel images, so grayscale
            # matrices are passed to imshow directly on a fixed 0..255 scale.
            axesout.imshow(img, cmap='gray', vmin=0, vmax=255)
        else:
            axesout.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        axesout.axis("off")
        axesout.title.set_text(title)
    except Exception as err:
        print(f"An error in ploting matrix: {err}")
    return
def print_output(X,A_train,y_train,A_test,y_test,Xtrue):
    """Print summary statistics of a recovered matrix X.

    Prints three lines:
      1. nuclear norm, effective rank, and numerical rank;
      2. training and test measurement losses;
      3. relative Frobenius error to Xtrue.
    """
    nuclear = torch.norm(X, p='nuc')
    line_1 = f'nuclear norm: {nuclear:.2f}, erank: {erank(X):.2f}, rank: {torch.linalg.matrix_rank(X):.2f}'
    print(line_1)
    line_2 = f'train loss: {mse_loss(X,A_train,y_train):.2e}, test loss: {mse_loss(X,A_test,y_test):.2e}'
    print(line_2)
    print(f'test error <recovering error>: {test_error(X,Xtrue):.2e}')

In [None]:
def run_gradient_descent(model, A_train, y_train, A_test, y_test,
                         lr=1e-4, num_iter=10000, repeat = 1,
                         freq=5, batch_size = None, momentum=0,
                         Xtrue= None, scheduler_flag=False):
    """Train ``model`` with (minibatch) SGD on the matrix-sensing MSE loss.

    ``model()`` takes no input and returns the current matrix estimate X; the
    training objective is ``mse_loss(X, A_batch, y_batch)``. Each learning rate
    is tried ``repeat`` times, every time from a fresh deep copy of the initial
    model. The run with the lowest final full-training-set loss is kept.

    Args:
        model: nn.Module whose forward() takes no arguments.
        A_train, y_train, A_test, y_test: measurement matrices and targets.
        lr: a float, a list/tuple of floats, or 'search'. 'search' sweeps
            {7.5, 5, 2.5, 1} x 10^k for k = 1 .. -11 and returns as soon as a
            successful run ends worse than the best one so far.
        num_iter: SGD steps per run.
        repeat: runs per learning rate.
        freq: record traces every ``freq`` steps (values < 1 are treated as 1).
        batch_size: minibatch size; defaults to the whole training set.
        momentum: SGD momentum.
        Xtrue: ground truth for the relative test error (1.0 is recorded if None).
        scheduler_flag: if True, step a CosineAnnealingLR schedule after every
            optimizer step.

    Returns:
        (traces_best, X_best, model_best, lr_best, target_traces).
        target_traces holds the final training loss of every run (inf for
        failed runs).

    Raises:
        RuntimeError: if every run failed (e.g. the loss became NaN/inf).
    """
    target_best = float('inf')
    target_traces = []
    traces_best, model_best, lr_best = None, None, None
    model_org = copy.deepcopy(model)
    if batch_size is None:
        batch_size = A_train.shape[0]
    # A zero recording period would make `i % freq` raise ZeroDivisionError.
    freq = max(1, int(freq))
    if lr == 'search':
        lr_list = [ ii*10**rate for rate in range(1,-12,-1) for ii in [7.5,5.0,2.5,1.0]]
        return_tag = True
    elif isinstance(lr, (list, tuple)):
        lr_list = list(lr)
        return_tag = False
    else:
        lr_list = [lr]
        return_tag = False

    for rep in (mbar:= master_bar(range(repeat*len(lr_list)))):
        lr = lr_list[rep//repeat]
        try:
            model = copy.deepcopy(model_org)
            optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum = momentum)
            # T_max = 0 would make the cosine schedule divide by zero.
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max(1, num_iter//111))
            train_loss_traces, test_loss_traces, test_error_traces = [], [], []
            singular_value_traces, nuclear_norm_traces, erank_traces = [],[],[]
            grad_traces, lr_traces, top_sv_traces, bot_sv_traces, X_traces, iter_traces = [], [], [], [], [], []
            # The dict references the lists above, which are appended in place.
            traces = {  '_': iter_traces,
                'train_loss_traces': train_loss_traces,
                'test_loss_traces': test_loss_traces,
                'test_error_traces': test_error_traces,
                'erank_traces': erank_traces,
                'nuclear_norm_traces': nuclear_norm_traces,
                'top_sv_traces': top_sv_traces,
                'bot_sv_traces': bot_sv_traces,
                'grad_traces': grad_traces,
                'lr_traces': lr_traces,
                'singular_value_traces': singular_value_traces,
                'X_traces': X_traces  }
            batch_loader = DataLoader(torch.utils.data.TensorDataset(A_train, y_train),
                                      batch_size=batch_size, shuffle=True)
            # train
            model.train()
            for i in progress_bar(range(num_iter), parent=mbar):
                # A fresh iterator per step draws a new random minibatch (shuffle=True).
                (A_batch,y_batch) = next(iter(batch_loader))
                X_temp = model()
                loss = mse_loss(X_temp,A_batch,y_batch)
                with torch.no_grad():
                    loss_train = mse_loss(X_temp, A_train, y_train)
                if torch.isnan(loss_train) or torch.isinf(loss_train):
                    raise Exception('numerical error')
                optimizer.zero_grad()
                loss.backward()
                # record (after backward, so the gradient norm belongs to this step's loss)
                if i % freq == 0:
                    with torch.no_grad():
                        iter_traces.append(i)
                        train_loss_traces.append(loss_train.detach().numpy())
                        loss_test = mse_loss(X_temp, A_test, y_test)
                        test_loss_traces.append(loss_test.detach().numpy())

                        nucnorm = torch.norm(X_temp,p='nuc').detach().numpy()
                        svs = torch.linalg.svdvals(X_temp).detach().numpy()
                        top_sv_traces.append(sorted(svs)[-3:])
                        bot_sv_traces.append(sorted(svs)[:2])
                        singular_value_traces.append(svs)
                        nuclear_norm_traces.append(nucnorm)

                        erank_traces.append(erank(X_temp).detach().numpy())
                        # Detached so the stored snapshots do not keep autograd graphs alive.
                        X_traces.append(X_temp.detach().clone())
                        grad_traces.append(get_norm_grad(model).detach().numpy())
                        cur_lr = scheduler.get_last_lr()[0]
                        lr_traces.append(cur_lr)
                        if Xtrue is None:
                            test_error_val = 1.0
                        else:
                            test_error_val = test_error(X_temp,Xtrue).detach().numpy()
                        test_error_traces.append(test_error_val)

                optimizer.step()
                if scheduler_flag:
                    # PyTorch expects scheduler.step() after optimizer.step().
                    scheduler.step()
                # print
                mbar.child.comment = f' ] [train {loss_train:.2e}, nuc-nrm {nucnorm:.2f}, lr {cur_lr:.2e}, test err {test_error_val:.2e}'

            model.eval()
            target_val = loss_train.detach().numpy()
            success_flag = True
        except Exception:
            # A failed run (e.g. divergence) is recorded with infinite loss and skipped.
            target_val = float('inf')
            success_flag = False
        # Compare
        target_traces.append(target_val)
        if success_flag and (target_val < target_best or rep == 0):
            model_best = copy.deepcopy(model)
            target_best = target_val
            traces_best = traces
            lr_best = lr
            # print
            mbar.main_bar.comment = f'] [best train {target_val:.2e} <{lr_best:.2e}>'
        elif success_flag and return_tag and target_val > target_best:
            return traces_best, model_best(), model_best, lr_best, target_traces
    if model_best is None:
        raise RuntimeError('run_gradient_descent: every run failed; no model to return')
    return traces_best, model_best(), model_best, lr_best, target_traces

In [None]:
class ParamsInit(torch.nn.Module):
    """Shared initial parameters for the SNN / WOA / DF / LG models below.

    Creates ``in_dim[0]`` pairs of (n, n) factors (U_i, V_i), one channel-mixing
    matrix ``alphas[l]`` of shape (all_dim[l], all_dim[l+1]) per layer, where
    ``all_dim = list(in_dim) + [1]``, and an (n, n) zero ``bias``.

    Args:
        n: matrix size.
        in_dim: channel width of each layer; the first entry is the number of
            factor pairs.
        init_tag: how the factors are drawn:
            'well_spec' -> spectral_init(n, Phi, PsiT) (requires Phi and PsiT),
            'random'    -> i.i.d. U(0, 1) entries,
            'diag'      -> diagonal matrices with U(0, 1) entries,
            'identity'  -> 1e-4 * I.
        Phi, PsiT: singular-vector matrices, used only by 'well_spec'.

    Raises:
        ValueError: for an unknown ``init_tag``, or 'well_spec' without Phi/PsiT.
    """
    INIT_TAGS = ('well_spec', 'random', 'diag', 'identity')

    def __init__(self, n, in_dim=(1,), init_tag = None, Phi=None, PsiT=None):
        super(ParamsInit, self).__init__()
        # Copy into a fresh list: the default is immutable and callers' lists are never aliased.
        in_dim = list(in_dim)
        if init_tag not in self.INIT_TAGS:
            raise ValueError(f'init_tag must be one of {self.INIT_TAGS}, got {init_tag!r}')
        if init_tag == 'well_spec' and (Phi is None or PsiT is None):
            raise ValueError("init_tag='well_spec' requires Phi and PsiT")
        self.n = n
        self.first_dim = in_dim[0]
        self.all_dim = in_dim + [1]
        self.n_layers = len(in_dim)
        self.Us = torch.nn.ParameterList()
        self.Vs = torch.nn.ParameterList()
        self.alphas = torch.nn.ParameterList()
        self.init_tag = init_tag
        # Initialize bias (the model classes below do not read it)
        self.bias = torch.nn.parameter.Parameter(torch.zeros(n,n))
        # Initialize Us and Vs
        for _ in range(self.first_dim):
            if init_tag == 'well_spec':
                U, V = spectral_init(n, Phi, PsiT)
            elif init_tag == 'random':
                U = torch.rand(n,n)
                V = torch.rand(n,n)
            elif init_tag == 'diag':
                U = torch.diag(torch.rand(n))
                V = torch.diag(torch.rand(n))
            else:  # 'identity': small scaled identity
                eps = 10**-4
                U, V = torch.eye(n)*eps,torch.eye(n)*eps
            self.Us.append(torch.nn.Parameter(U))
            self.Vs.append(torch.nn.Parameter(V))
        # Initialize alphas: layer i mixes all_dim[i] channels into all_dim[i+1]
        for i in range(self.n_layers):
            alpha_i = torch.nn.Parameter(torch.rand(self.all_dim[i],self.all_dim[i+1]))
            self.alphas.append(alpha_i)
class SNN(torch.nn.Module):
    """Layered model on the stack of products U_i V_i^T.

    Each layer applies ``spectral_nonlinear`` (tanh on the singular values) to
    every channel and then mixes channels with ``alphas[l]``. The final layer has
    width 1, and its single (n, n) channel is returned.

    With an 'identity' reference init, the factors get tiny diagonal noise. The
    perturbed copies live in this model's own ParameterLists, so ``model_ref`` is
    not modified. Otherwise the factor lists are shared with ``model_ref``.
    """
    def __init__(self, model_ref=None):
        super(SNN, self).__init__()
        self.n = model_ref.n
        self.first_dim = model_ref.first_dim
        self.all_dim = model_ref.all_dim
        self.model_name = 'snn_' + model_ref.init_tag+'_'+str(self.all_dim)
        self.n_layers = model_ref.n_layers
        if model_ref.init_tag == 'identity':
            # Break the symmetry of the identical identity factors with 1e-6 diagonal noise.
            Us, Vs = [], []
            for U, V in zip(model_ref.Us, model_ref.Vs):
                Us.append(torch.nn.Parameter(U.detach() + torch.diag_embed(1e-6*torch.randn(self.n))))
                Vs.append(torch.nn.Parameter(V.detach() + torch.diag_embed(1e-6*torch.randn(self.n))))
            self.Us = torch.nn.ParameterList(Us)
            self.Vs = torch.nn.ParameterList(Vs)
        else:
            self.Us = model_ref.Us
            self.Vs = model_ref.Vs
        self.alphas = model_ref.alphas

    def forward(self):
        # (first_dim, n, n) stack of the products U_i V_i^T
        X = torch.stack([torch.matmul(U, V.T) for U, V in zip(self.Us, self.Vs)], dim=0)
        for i in range(self.n_layers):
            X = spectral_nonlinear(X)
            # Mix channels: (k, n, n) -> (n, n, k) @ (k, k') -> (k', n, n)
            X_T = torch.permute(X, (1, 2, 0))
            X = torch.permute(torch.matmul(X_T, self.alphas[i]), (2, 0, 1))
        # Only drop the width-1 channel axis, so n = 1 still yields a (1, 1) matrix.
        return X.squeeze(0)
class WOA(torch.nn.Module):
    """Same channel-mixing layers as SNN but without ``spectral_nonlinear``.

    The name presumably stands for "without activation". Because every layer is
    linear, the output is a linear combination of the products U_i V_i^T. All
    parameter lists are shared with ``model_ref``.
    """
    def __init__(self, model_ref=None):
        super(WOA, self).__init__()
        self.n = model_ref.n
        self.first_dim = model_ref.first_dim
        # model_name reads self.first_dim, so it must be built after that assignment.
        self.model_name = 'woa_' + model_ref.init_tag +'_D'+str(self.first_dim)
        self.all_dim = model_ref.all_dim
        self.n_layers = model_ref.n_layers
        self.Us = model_ref.Us
        self.Vs = model_ref.Vs
        self.alphas = model_ref.alphas

    def forward(self):
        # (first_dim, n, n) stack of the products U_i V_i^T
        X = torch.stack([torch.matmul(U, V.T) for U, V in zip(self.Us, self.Vs)], dim=0)
        for i in range(self.n_layers):
            # Mix channels: (k, n, n) -> (n, n, k) @ (k, k') -> (k', n, n)
            X_T = torch.permute(X, (1, 2, 0))
            X = torch.permute(torch.matmul(X_T, self.alphas[i]), (2, 0, 1))
        # Only drop the width-1 channel axis, so n = 1 still yields a (1, 1) matrix.
        return X.squeeze(0)
class DF(torch.nn.Module):
    """Deep matrix factorization: X = W_1 W_2 ... W_depth.

    The initial factors are copied, in order, from ``model_ref`` as
    U_0, V_0^T, U_1, V_1^T, ... and truncated to ``depth`` factors
    (default: all 2 * first_dim of them).

    Raises:
        ValueError: if ``depth`` is outside [1, 2 * model_ref.first_dim].
    """
    def __init__(self, model_ref=None, depth=None):
        super(DF, self).__init__()
        self.n = model_ref.n
        self.first_dim = model_ref.first_dim
        max_depth = 2 * self.first_dim
        self.depth = max_depth if depth is None else depth
        if not 1 <= self.depth <= max_depth:
            raise ValueError(f'depth must be between 1 and {max_depth}, got {self.depth}')
        self.model_name = 'df_' + model_ref.init_tag+'_'+str(self.depth)

        factors = []
        for i in range(self.first_dim):
            factors.append(model_ref.Us[i])
            factors.append(torch.transpose(model_ref.Vs[i], 0, 1))
        # Clone the values: Parameter(t) shares t's storage, so without the copy,
        # in-place training of DF would also overwrite model_ref's factors.
        self.Ws = torch.nn.ParameterList(
            [torch.nn.Parameter(W.detach().clone()) for W in factors[:self.depth]])

    def forward(self):
        X = torch.eye(self.n)
        for W in self.Ws:
            X = torch.matmul(X, W)
        return X

class LG(torch.nn.Module):
    """Plain parameterization: the estimate is a single trainable n x n matrix W, initialized to zero."""

    def __init__(self, model_ref=None):
        super().__init__()
        self.model_name = 'lg'
        size = model_ref.n
        self.n = size
        self.W = torch.nn.Parameter(torch.zeros((size, size)))

    def forward(self):
        """Return the trainable matrix itself."""
        return self.W

In [None]:
def assign_subtensor(bigmat,row_idx,col_idx,submat):
    """Scatter submat into bigmat at the given row/column indices.

    Sets bigmat[row_idx[r], col_idx[c]] = submat[r, c] for every (r, c), in row-major
    order. bigmat is modified in place and also returned.
    """
    index_pairs = ((r_sub, r_big, c_sub, c_big)
                   for r_sub, r_big in enumerate(row_idx)
                   for c_sub, c_big in enumerate(col_idx))
    for r_sub, r_big, c_sub, c_big in index_pairs:
        bigmat[r_big, c_big] = submat[r_sub, c_sub]
    return bigmat

def generate_X(n, Xlabel=0, nlow=None, Xtrue = None, image_path = None, prescale=True):
    """Build the ground-truth matrix and its SVD.

    Args:
        n: matrix size.
        Xlabel: 'synthetic'  -> random PSD matrix 100 * U U^T with U ~ U(0, 1) of
                                shape (n, nlow) (rank <= nlow; nlow defaults to n);
                'predefined' -> ``Xtrue`` given as a numpy array, cast to float32
                                and resized to (n, n);
                anything else -> ``Xtrue`` is used as given (it must not be None).
        prescale: if True, divide by the nuclear norm so that ||Xtrue||_* = 1.
        image_path: unused.

    Returns:
        (Xtrue, X_scale, Phi, PsiT, X_svd_vec). X_scale is the divisor applied
        (1 if not prescaled), and Phi diag(X_svd_vec) PsiT is the SVD of the
        returned Xtrue.

    Raises:
        ValueError: if no matrix is available (unknown Xlabel and Xtrue is None).
    """
    if Xlabel == 'synthetic': # random, PSD
        if nlow is None:
            nlow = n
        Utemp = torch.rand(n,nlow)
        Xtrue = Utemp @ Utemp.T
        Xtrue = Xtrue * 100.0
    elif Xlabel == 'predefined': # pre-defined Xtrue
        Xtrue = torch.from_numpy(Xtrue).float()
        Xtrue_unsq = Xtrue[None, :, :]
        avg_fun = Transform.Resize((n,n))
        Xtrue = avg_fun(Xtrue_unsq).squeeze()
    elif Xtrue is None:
        raise ValueError(f"unknown Xlabel {Xlabel!r} and no Xtrue given")

    # Scale
    if prescale:
        X_scale = torch.norm(Xtrue,p='nuc')
    else:
        X_scale = 1
    # Out-of-place: `/=` could write through to a caller-owned tensor or to the
    # numpy buffer shared via torch.from_numpy.
    Xtrue = Xtrue / X_scale
    Phi, X_svd_vec, PsiT = torch.linalg.svd(Xtrue)

    return  Xtrue, X_scale, Phi, PsiT, X_svd_vec


def generate_A(Xtrue, Phi=None, PsiT=None, m=10, Alabel=0, noise=True, noise_std=1e-2):
    """Draw 2*m measurement matrices and measurements y_i = <A_i, Xtrue>, split m/m into train/test.

    Args:
        Xtrue: (n, n) ground-truth matrix.
        Phi, PsiT: left singular vectors and transposed right singular vectors
            (as returned by generate_X); required for 'well_spec'.
        m: number of training measurements; another m are drawn for the test set.
        Alabel: 'well_spec' -> A_i = Phi diag(s_i) PsiT with s_i = n * N(0, 1)
                               sorted in descending order (built in Phi / PsiT);
                'mis_spec'  -> A_i with i.i.d. n * N(0, 1) entries.
        noise: if True, add i.i.d. Gaussian noise to y.
        noise_std: standard deviation of that noise (default 1e-2).

    Returns:
        A_train (m, n, n), y_train (m,), A_test (m, n, n), y_test (m,), and the
        list of singular values of every A_i (training ones first).

    Raises:
        ValueError: for an unknown Alabel, or 'well_spec' without Phi/PsiT.
    """
    if Alabel not in ('well_spec', 'mis_spec'):
        raise ValueError(f"Alabel must be 'well_spec' or 'mis_spec', got {Alabel!r}")
    if Alabel == 'well_spec' and (Phi is None or PsiT is None):
        raise ValueError("Alabel='well_spec' requires Phi and PsiT")

    n = Xtrue.shape[0]

    A = []
    A_svd_vec = []
    for i in progress_bar(range(2*m)):
        if Alabel == 'well_spec':
            spectrum_i, _ = torch.sort(torch.randn(n)*n, descending=True)
            Ai = Phi @ torch.diag(spectrum_i) @ PsiT
        else:  # 'mis_spec'
            Ai = torch.randn(n, n)*n
        A.append(Ai)
        A_svd_vec.append(torch.linalg.svdvals(Ai))

    A = torch.stack(A)
    y = dot(A, Xtrue)
    if noise:
        y = y + torch.randn(2*m)*noise_std

    A_train = A[:m]
    y_train = y[:m]
    A_test  = A[m:]
    y_test  = y[m:]
    return A_train, y_train, A_test, y_test, A_svd_vec

# Generate dataset & build networks

In [None]:
'''Choose one of the following line'''
# Experiment selector -- keep exactly one configuration line active:
#   Example    : 'syn' (random PSD matrix from generate_X) or 'dgt' (an MNIST digit image)
#   Assumption : how the measurement matrices are drawn; passed as Alabel to generate_A
#   Noisy      : passed as `noise` to generate_A (adds Gaussian noise to y)
#   Figure     : selects the iteration budget, network widths and model set below
#   Findex     : index of the MNIST training image used as ground truth ('dgt'); it is
#                also part of the Figure3 output file name. The 'syn' lines do not define it.
# Example, Assumption, Noisy, Figure = 'syn', 'well_spec', True, 'Figure1' # Figure 1
# Example, Assumption, Noisy, Figure = 'syn', 'mis_spec', True, 'Figure2' # Figure 2
Example, Assumption, Noisy, Figure, Findex = 'dgt', 'mis_spec', True, 'Figure3', 0 # Figure 3
'''--------------------------------------------------------------------------'''
if __name__ == "__main__":


    # Generate X
    # generate_X normalizes Xtrue to unit nuclear norm; X_scale is the factor that undoes it.
    if Example == 'syn':
        n, nlow = 10, 6
        Xtrue, X_scale, Phi, PsiT, _ = generate_X(n=n, Xlabel='synthetic', nlow=nlow)
        # Replace the requested rank with the numerical rank actually obtained.
        nlow = torch.linalg.matrix_rank(Xtrue)
        if Figure == 'Figure1':
            num_iter = 100001
        else:
            num_iter = 10001
        m = 75  # training measurements (generate_A draws another m for testing)
        batch_size = m  # full-batch gradient descent
        num_lr = 5  # learning rates that must succeed per model in the execution cell
    elif Example == 'dgt':
        # NOTE(review): fashion_mnist is imported but not used.
        from keras.datasets import mnist, fashion_mnist
        (Xmnist, _), (_, _) = mnist.load_data()
        X_dgt = Xmnist[Findex,:,:]
        n = X_dgt.shape[0]  # image height (28 for MNIST); the image is treated as an n x n matrix
        Xtrue, X_scale, Phi, PsiT, _ = generate_X(n=n, Xlabel='predefined', Xtrue=X_dgt)
        num_iter = 2501
        m = round(n*n*0.6)  # 60% as many measurements as unknown entries
        batch_size = m  # full-batch gradient descent
        num_lr = 3

    # Generate A
    if Assumption == 'well_spec':
        # The A_i are built in Xtrue's singular bases Phi / PsiT.
        A_train, y_train, A_test, y_test, _ = generate_A(Xtrue, Phi=Phi, PsiT=PsiT, m=m, Alabel='well_spec', noise=Noisy)
    elif Assumption == 'mis_spec':
        # Clear Phi / PsiT so nothing below can use Xtrue's singular vectors.
        Phi, PsiT = None, None
        A_train, y_train, A_test, y_test, _ = generate_A(Xtrue, m=m, Alabel='mis_spec', noise=Noisy)


    # Build networks
    # in_dim: per-layer channel widths for ParamsInit; in_dim[0] is the number of (U_i, V_i) pairs.
    if Figure == 'Figure1' or Figure == 'Figure2':
        in_dim = [2,2,2,2]
    else:
        in_dim = [4,8,16]



    if Figure == 'Figure1':
        params_init_w = ParamsInit(n, in_dim=in_dim, Phi=Phi, PsiT=PsiT, init_tag='well_spec') # well spec
        model_snn_w = SNN(params_init_w)
        list_model_to_run = [model_snn_w]
    elif Figure == 'Figure2':
        params_init_i = ParamsInit(n, in_dim=in_dim, Phi=None, PsiT=None, init_tag='identity') # identity
        model_snn_i = SNN(params_init_i)

        model_df_i3 = DF(params_init_i, depth=3)

        model_lg = LG(params_init_i)
        list_model_to_run = [model_snn_i,model_lg,model_df_i3]
    elif Figure == 'Figure3':
        params_init_i = ParamsInit(n, in_dim=in_dim, Phi=None, PsiT=None, init_tag='identity') # identity
        model_snn_i = SNN(params_init_i)

        model_lg = LG(params_init_i)
        list_model_to_run = [model_lg,model_snn_i]

    # Display
    # Xtrue's own train/test loss reflects only the measurement noise, since y = <A_i, Xtrue> (+ noise).
    # NOTE(review): this prints the full Phi / PsiT / Xtrue / A_0 matrices.
    print('+'*100)
    nucnorm = torch.norm(Xtrue,p='nuc')
    print(f'n = {n}, nlow = {torch.linalg.matrix_rank(Xtrue)}, erank = {erank(Xtrue):.2f}, nuclear norm: {nucnorm:.2f}'
          f'\nm {m}, batch {batch_size}'
          f'\ntrain loss: {mse_loss(Xtrue,A_train,y_train):.2e}, test loss: {mse_loss(Xtrue,A_test,y_test):.2e}'
          f'\nPhi {Phi}, \nPsiT {PsiT} \nXtrue {Xtrue}\nAi {A_train[0]}\nyi {y_train[0]}')
    # Undo the nuclear-norm scaling so pixel values return to their original range.
    plot_matrix(Xtrue*X_scale, title='ground-truth')
    for model in list_model_to_run:
        print('+'*100)
        print(f'model {model.model_name}: no of params {count_parameters(model)}\ninitalization X {model()}')

# Execute

In [None]:
if __name__ == "__main__":
    import os

    problem_id = Example+'_'+Assumption+'_'+str(Noisy)
    model_names = [model.model_name for model in list_model_to_run]
    # Candidate learning rates, largest first: 5e1, 1e1, 5e0, 1e0, ..., 5e-11, 1e-11.
    lrs = [ ii*10**rate for rate in range(1,-12,-1) for ii in [5.0,1.0]]


    print(f'problem id {problem_id}\nm = {m}, n = {n}, iter = {num_iter}, batch = {batch_size}'
          f'\nmodel: {model_names}\nlrs: {lrs}')

    Outputs = {'problem_id':problem_id,
               'list_model': list_model_to_run,
               'model_names': model_names,
               'Xtrue': Xtrue,
               'X_scale': X_scale,
               'A_train': A_train,
               'y_train': y_train,
               'A_test': A_test,
               'y_test': y_test,
              }

    # One pickle per figure / number of measurements (/ MNIST index for Figure3) / problem.
    if Figure=='Figure3':
        file_save_name = 'output/'+Figure+'_'+str(m)+'_'+str(Findex)+'_'+problem_id+'.pkl'
    else:
        file_save_name = 'output/'+Figure+'_'+str(m)+'_'+problem_id+'.pkl'
    os.makedirs(os.path.dirname(file_save_name), exist_ok=True)

    for m_count, model in enumerate(list_model_to_run):
        key_out = model.model_name
        print('\n\n\n'+'+'*100+'\n')

        results = {}
        lr_count = 0
        # Deep-factorization ('df') models get twice the iterations and must
        # succeed on twice as many learning rates.
        factor = 2 if 'df' in key_out else 1
        n_steps = round(num_iter*factor)

        for lr in lrs:
            if lr_count >= num_lr*factor:
                break  # enough successful learning rates for this model
            print(f'model {key_out}, lr = {lr:.2e}, {m_count+1} of {len(list_model_to_run)} models')
            try:
                model_in = copy.deepcopy(model)
                traces_out, X_out, model_out, lr_out, target_traces = run_gradient_descent(
                    model_in, A_train, y_train, A_test, y_test,
                    lr=lr, num_iter=n_steps,
                    freq=max(1, n_steps//100), batch_size=batch_size,
                    repeat = 1, Xtrue = Xtrue)
                results[lr] = traces_out
                print_output(X_out,A_train,y_train,A_test,y_test,Xtrue)
                lr_count += 1
            except Exception as emes:
                print(emes)
            print('-'*100)

            # Save after every learning rate so partial results survive an interruption.
            Outputs[key_out] = results
            with open(file_save_name, 'wb') as file:
                pickle.dump(Outputs, file)