In [1]:
import os
import time

import cvxpy as cp
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt

import utils
import spectral_nti as snti

%matplotlib qt

# Seed NumPy's global RNG so the samples drawn later with np.random.randn
# (the matrix X) are reproducible from run to run.
SEED = 0
np.random.seed(seed=SEED)

In [2]:
def plot_err(err, alphas, betas, gammas, bounds, label='A'):
    """Show heat maps of an error grid from the hyper-parameter sweep.

    Parameters
    ----------
    err : ndarray
        Error grid indexed as ``err[gamma_idx, beta_idx, alpha_idx]``, i.e. of
        shape ``(len(gammas), len(betas), len(alphas))``.
    alphas, betas, gammas : sequence
        Swept regularisation values, used as tick labels and in titles.
    bounds : sequence or None
        The ``bounds`` list used in the sweep; only its emptiness is checked
        here. If it is empty (or None), gamma has no effect, so a single
        beta-by-alpha map for ``gammas[0]`` is drawn. Otherwise one
        gamma-by-beta map is drawn per alpha.
    label : str, optional
        Name of the error, shown in the figure titles.

    Returns
    -------
    None
        Figures are created through ``matplotlib.pyplot``.
    """
    def _heatmap(grid, x_name, x_vals, y_name, y_vals, title):
        # One new figure holding a single heat map with its colour bar; tick i
        # on each axis is labelled with the i-th swept value.
        fig, ax = plt.subplots()
        image = ax.imshow(grid)
        fig.colorbar(image, ax=ax)
        ax.set_xlabel(x_name)
        ax.set_xticks(np.arange(len(x_vals)))
        ax.set_xticklabels(x_vals)
        ax.set_ylabel(y_name)
        ax.set_yticks(np.arange(len(y_vals)))
        ax.set_yticklabels(y_vals)
        ax.set_title(title)

    if bounds is None or len(bounds) == 0:
        # Without upper bounds, gamma does not matter: plot only its first slice
        # (rows = betas, columns = alphas).
        _heatmap(err[0, :, :], 'Alpha', alphas, 'Beta', betas,
                 'Err {}, Gamma: {}'.format(label, gammas[0]))
    else:
        # With upper bounds, gamma matters: one map per alpha
        # (rows = gammas, columns = betas).
        for k, alpha in enumerate(alphas):
            _heatmap(err[:, :, k], 'Beta', betas, 'Gamma', gammas,
                     'Err {}, Alpha: {}'.format(label, alpha))

In [31]:
# Regs: regularisation values swept by the experiment loop.
# Wider sweeps tried earlier:
#   alphas: [1e-4, 1e-2, .1, .5, 1]
#   betas:  np.arange(.25, 3.1, .25)
#           np.insert(np.array([5, 10, 25]), 0, np.arange(.25, 1.6, .25))
#   gammas: [10, 50, 100, 150, 200, 250, 500, 750, 1000] or [1, 10, 25, 50, 100]
alphas, betas = [0], [1.25]
gammas = [0, 100]

# Model params: SGL_MM iteration budget (max_iters) and number of samples M
# (columns of the sampled signal matrix X).
iters, M = 200, 1000

# Tolerances handed to SGL_MM through regs['deltas'].
# Earlier alternatives: [3e-3] and [0].
deltas = [0.04, 0.27, 0.02]

# Spectral functions g(a, b) applied to eigenvalue vectors (they are passed to
# utils.compute_cs(gs, lambdas0, lambdas) and to snti.SGL_MM); `b` divides
# each sum and is presumably the number of nodes -- TODO confirm in
# utils / spectral_nti.
# NOTE(review): the "delta" notes below disagree with the `deltas` list in use
# ([4e-2, .27, 2e-2]) for the first two entries; they look stale.
gs = [
    lambda a, b : cp.sum(a)/b,    # delta: 7e-2
    lambda a, b : cp.sum(a**2)/b,  # delta: .7
    #lambda a, b : cp.sum(cp.exp(-a))/b,    # delta: 3e-3
    lambda a, b : cp.sum(cp.sqrt(a))/b,  # delta: 2e-2
]
# Functions linear in `lamd`, built around a second point `lamd_t`
# (presumably the previous MM iterate -- confirm in spectral_nti.SGL_MM).
# Up to terms constant in `lamd`, each is the first-order expansion at lamd_t of:
#   bounds[0]: -sum(lamd**2)/b
#   bounds[1]: -sum(exp(-lamd))/b
#   bounds[2]:  sum(sqrt(lamd))/b
# NOTE(review): bounds[1] corresponds to the exp function that is commented
# out of `gs`; if SGL_MM pairs bounds[i] with gs[i], the two lists are
# misaligned -- confirm.
# NOTE(review): bounds[2] divides by sqrt(lamd_t); a zero entry in lamd_t
# (e.g. a Laplacian's zero eigenvalue) would yield inf/NaN.
bounds = [
    lambda lamd, lamd_t, b : -2/b*lamd_t.T@lamd,
    lambda lamd, lamd_t, b : 1/b*cp.exp(-lamd_t).T@lamd,
    lambda lamd, lamd_t, b : cp.sum(lamd/cp.sqrt(lamd_t))/(2*b),
]

# Ref graph params: size of the n01 x n02 reference grid.
n01, n02 = 15, 10

# Target graph params: size of the n1 x n2 grid whose adjacency is recovered.
n1, n2 = 20, 10

In [32]:
# Create graphs
# Reference graph: n01 x n02 grid, combinatorial Laplacian L0 = D0 - A0 and
# its eigenvalues (eigh returns them in ascending order).
N0 = n01 * n02
A0 = nx.to_numpy_array(nx.grid_2d_graph(n01, n02))
L0 = np.diag(A0.sum(axis=0)) - A0
lambdas0, _ = np.linalg.eigh(L0)

# Target graph: same construction on an n1 x n2 grid. The eigenvectors V are
# kept because the sampling matrix below is built from them.
N = n1 * n2
A = nx.to_numpy_array(nx.grid_2d_graph(n1, n2))
L = np.diag(A.sum(axis=0)) - A
lambdas, V = np.linalg.eigh(L)

# Squared norms used to normalise the reported errors.
A_n = np.linalg.norm(A, ord='fro')**2
lambs_n = np.linalg.norm(lambdas)**2

# Create C
# Scale each eigen-direction by 1/sqrt(lambda), except the first one
# (lambdas[0], the zero eigenvalue of a connected graph's Laplacian), which is
# dropped. Assuming that is the only zero eigenvalue, C_inv_sqrt is the square
# root of pinv(L); the columns of X then have covariance pinv(L), and C_hat is
# their sample covariance over the M samples.
lambdas_aux = np.concatenate(([0.0], 1 / np.sqrt(lambdas[1:])))
C_inv_sqrt = V @ np.diag(lambdas_aux) @ V.T
X = C_inv_sqrt @ np.random.randn(N, M)
C_hat = X @ X.T / M

# Get values from the reference graph
cs = utils.compute_cs(gs, lambdas0, lambdas)

	c-0: c: 3.700	c0: 3.667	err: 0.033333	err norm: 0.009091
	c-1: c: 17.640	c0: 17.387	err: 0.253333	err norm: 0.014571
	c-2: c: 1.828	c0: 1.817	err: 0.010217	err norm: 0.005622


In [33]:
# Sweep every (alpha, beta, gamma) combination, fit with snti.SGL_MM and store
# three normalised squared errors in arrays indexed [gamma, beta, alpha]:
#   err_A    -- ||A - A_hat||_F^2 / ||A||_F^2
#   err_lam  -- ||lambdas - lam_hat||^2 / ||lambdas||^2, lam_hat as returned by SGL_MM
#   err_lam2 -- same, using the eigenvalues of the returned L_hat instead
# NOTE(review): err_lam assumes lam_hat is ordered like `lambdas` (ascending,
# as returned by eigh) -- confirm in spectral_nti.SGL_MM.
regs = {'alpha': 0, 'beta': 0, 'gamma': 0, 'deltas': deltas}

# perf_counter is monotonic, so the measured duration cannot be skewed by
# wall-clock adjustments during the sweep.
t_start = time.perf_counter()
grid_shape = (len(gammas), len(betas), len(alphas))
err_A = np.zeros(grid_shape)
err_lam = np.zeros(grid_shape)
err_lam2 = np.zeros(grid_shape)
for k, alpha in enumerate(alphas):
    regs['alpha'] = alpha
    print('Alpha:', alpha)
    for j, beta in enumerate(betas):
        regs['beta'] = beta
        print('\tBeta:', beta)
        for i, gamma in enumerate(gammas):
            regs['gamma'] = gamma
            try:
                L_hat, lam_hat, _ = snti.SGL_MM(C_hat, gs, bounds, cs,
                                                regs, max_iters=iters)
            except cp.SolverError:
                # Failed solves are marked with error 1 in every array.
                # NOTE(review): 1 is not guaranteed to exceed the error of a
                # successful run, so a failed point can still be the argmin.
                err_A[i, j, k] = err_lam[i, j, k] = err_lam2[i, j, k] = 1
                print('\t\tGamma: {}: Solver Error'.format(gamma))
                continue
            # Adjacency from the Laplacian: A_hat = diag(L_hat) - L_hat.
            A_hat = np.diag(np.diag(L_hat)) - L_hat
            # Only the spectrum of L_hat is needed, so skip the eigenvectors.
            lamd2_hat = np.linalg.eigvalsh(L_hat)

            err_A[i, j, k] = np.linalg.norm(A - A_hat, 'fro')**2/A_n
            err_lam[i, j, k] = np.linalg.norm(lambdas - lam_hat)**2/lambs_n
            err_lam2[i, j, k] = np.linalg.norm(lambdas - lamd2_hat)**2/lambs_n
            print('\t\tGamma: {}: ErrA: {:.3f}'.format(gamma, err_A[i, j, k]))

# `t` holds the sweep duration in seconds.
t = time.perf_counter() - t_start
print('-----', str(t/60), 'mins -----')

Alpha: 0
	Beta: 1.25
		Gamma: 0: ErrA: 0.032
CONVERGENCE ACHIEVED
		Gamma: 100: ErrA: 0.033
----- 2.2286606272061666 mins -----


In [34]:
# Output location for the error grids. Built with os.path so the separator is
# correct on every OS: hard-coded backslashes only work on Windows, and '\s' /
# '\h' are invalid string escapes (DeprecationWarning / SyntaxWarning).
# f1/f2 keep a leading separator so that `path + f1` is the full file path.
# np.save appends '.npy' to these names.
path = os.path.join('results', 'samples_grid_graph')
f1 = os.sep + 'heat_tight_errA'
f2 = os.sep + 'heat_tight_errL'
#np.save(path+f1, err_A)
#np.save(path+f2, err_lam2)

In [35]:
# Print results - Using pinv(C)
# For each error metric, locate its minimum over the sweep grid and report the
# companion metric at the same grid point. The error arrays are indexed
# [gamma, beta, alpha], hence idx[0] -> gammas, idx[1] -> betas, idx[2] -> alphas.
for metric_name, metric, other_name, other in (
        ('A', err_A, 'Lamb2', err_lam2),
        ('Lambd', err_lam, 'A', err_A),
        ('Lambd2', err_lam2, 'A', err_A),
):
    idx = np.unravel_index(np.argmin(metric), metric.shape)
    print('Min err {} (Alpha: {:.3g}, Beta: {:.3g}, Gamma: {:.3g}): {:.6f}\t Err in {}: {:.6f}'
          .format(metric_name, alphas[idx[2]], betas[idx[1]], gammas[idx[0]],
                  metric[idx], other_name, other[idx]))

Min err A (Alpha: 0, Beta: 1.25, Gamma: 0): 0.032332	 Err in Lamb2: 0.005854
Min err Lambd (Alpha: 0, Beta: 1.25, Gamma: 100): 0.002003	 Err in A: 0.032550
Min err Lambd2 (Alpha: 0, Beta: 1.25, Gamma: 100): 0.002767	 Err in A: 0.032550


In [22]:
# Print results - Using pinv(C)
# NOTE(review): this cell duplicates the In [35] results cell and was run out
# of order (In [22]). Its saved output (Beta 0.75, Gamma 1) comes from a
# configuration that is not in the current grid. TODO: delete this cell.
# For each error metric, locate its minimum over the sweep grid and report the
# companion metric at the same grid point. The error arrays are indexed
# [gamma, beta, alpha], hence idx[0] -> gammas, idx[1] -> betas, idx[2] -> alphas.
for metric_name, metric, other_name, other in (
        ('A', err_A, 'Lamb2', err_lam2),
        ('Lambd', err_lam, 'A', err_A),
        ('Lambd2', err_lam2, 'A', err_A),
):
    idx = np.unravel_index(np.argmin(metric), metric.shape)
    print('Min err {} (Alpha: {:.3g}, Beta: {:.3g}, Gamma: {:.3g}): {:.6f}\t Err in {}: {:.6f}'
          .format(metric_name, alphas[idx[2]], betas[idx[1]], gammas[idx[0]],
                  metric[idx], other_name, other[idx]))

Min err A (Alpha: 0, Beta: 0.75, Gamma: 1): 0.099765	 Err in Lamb2: 0.007466
Min err Lambd (Alpha: 0, Beta: 0.75, Gamma: 1): 0.007110	 Err in A: 0.099765
Min err Lambd2 (Alpha: 0, Beta: 0.75, Gamma: 1): 0.007466	 Err in A: 0.099765


In [8]:
# Heat maps of the normalised adjacency error err_A over the sweep grid.
plot_err(err_A, alphas, betas, gammas, bounds, label='A');

In [9]:
# Heat maps of the spectral error of the estimated Laplacian (err_lam2).
plot_err(err_lam2, alphas, betas, gammas, bounds=bounds, label='Lambd2');