In [None]:
import sys
import copy
import time

import numpy as np
import torch
import matplotlib.pyplot as plt

sys.path.insert(0, '..')
from graph_deep_decoder import datasets as ds
from graph_deep_decoder import utils
from graph_deep_decoder.architecture import GraphDecoder
from graph_deep_decoder.model import Model

# Reproducibility: seed numpy and torch before any random draw below.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Graph parameters, consumed by ds.create_graph. Presumably an SBM graph with
# 'N' nodes and 'k' communities, where 'p'/'q' are the intra-/inter-community
# edge probabilities -- TODO confirm against graph_deep_decoder.datasets.
Gs = {
    'type': ds.SBM,
    'N': 64,
    'k': 4,
    'type_z': ds.CONT,
    'p': 0.7,
    'q': 0.015,
}

# Signal parameters
K = 3              # number of filter coefficients drawn for hs
n_signals = 100    # signals fitted per decoder
params = 4         # forwarded to ds.bandlimited_signal
n_p = 0.1          # noise parameter forwarded to ds.GraphSignal.add_noise

# Model parameters
epochs = 1000      # training epochs per fit (also sizes the error arrays)
lr = 0.001         # learning rate passed to Model

JJ_samples = 1000  # NOTE(review): not referenced in the visible cells

# Build the graph described by Gs (seeded with SEED) and draw it.
G = ds.create_graph(Gs, SEED)
A = G.W.todense()  # dense weight matrix; NOTE(review): unused in the visible cells
plt.figure()
G.plot()

# Draw K random filter coefficients and normalize them to sum to one. H is
# taken from the diffused-white-signal helper and shared by every decoder
# below (presumably the graph filter matrix built from hs -- TODO confirm).
hs = np.random.rand(K)
hs = hs / hs.sum()
print('Filter coefs:', hs)
x_dw = ds.DiffusedWhiteGS(G, ds.NonLin.NONE, K, coefs=hs)
H = x_dw.H

# One experiment per decoder size: the first GraphDecoder argument is the
# feature count (presumably its latent width), all decoders share filter H,
# and 'leg' is the legend label used in the plots. Decoders are built in
# increasing-size order, exactly as listed.
exps = [{'dec': GraphDecoder(width, H), 'leg': str(width)}
        for width in (16, 32, 64, 150, 1000, 5000)]
leg = [e['leg'] for e in exps]

# Eigendecompose each decoder's analytical squared Jacobian. The eigenpairs
# are stored on the experiment dict: 'Lambda'/'V' are used below to build the
# band-limited test signals and to project the fitting errors.
start_time = time.time()
for exp in exps:
    JJ = exp['dec'].analytical_squared_jacobian()
    # Ordering convention (ascending/descending) is defined by utils.ordered_eig
    Lambda_JJ, V_JJ = utils.ordered_eig(JJ)
    exp['Lambda'] = Lambda_JJ
    exp['V'] = V_JJ

print('Jacobians done in {} minutes'.format((time.time() - start_time) / 60))

# Fit models. Arrays are indexed as (decoder, signal, epoch, node).
# err holds the error w.r.t. the clean signal x; err_wrt_n holds the error
# w.r.t. the noisy observation x_n (both as returned by Model.fit with
# reduce_err=False).
err = np.zeros((len(exps), n_signals, epochs, G.N))
err_wrt_n = np.zeros_like(err)
start_time = time.time()
for i in range(n_signals):
    for j, exp in enumerate(exps):
        # Each decoder gets its own band-limited signal, built from its
        # Jacobian eigenpairs, plus a noisy copy of it.
        x = ds.bandlimited_signal(exp['Lambda'], exp['V'], params)
        x_n = ds.GraphSignal.add_noise(x, n_p)
        # Deep copy so every fit starts from the decoder's original weights,
        # not from weights left over by a previous fit.
        model = Model(copy.deepcopy(exp['dec']), epochs=epochs, learning_rate=lr)
        err_wrt_n[j, i], err[j, i], _ = model.fit(x_n, x, reduce_err=False)

    print('Signal', i, 'done')

print('--- {} minutes ---'.format((time.time() - start_time) / 60))

Filter coefs: [0.29399155 0.38311672 0.32289173]
Jacobians done in 0.0015957276026407877 minnutes
Signal 0 done
Signal 1 done
Signal 2 done
Signal 3 done
Signal 4 done
Signal 5 done
Signal 6 done
Signal 7 done
Signal 8 done
Signal 9 done
Signal 10 done
Signal 11 done
Signal 12 done
Signal 13 done
Signal 14 done
Signal 15 done
Signal 16 done
Signal 17 done
Signal 18 done
Signal 19 done
Signal 20 done
Signal 21 done
Signal 22 done
Signal 23 done
Signal 24 done
Signal 25 done
Signal 26 done
Signal 27 done
Signal 28 done
Signal 29 done
Signal 30 done
Signal 31 done
Signal 32 done
Signal 33 done
Signal 34 done
Signal 35 done
Signal 36 done
Signal 37 done
Signal 38 done
Signal 39 done
Signal 40 done
Signal 41 done
Signal 42 done
Signal 43 done
Signal 44 done
Signal 45 done
Signal 46 done
Signal 47 done
Signal 48 done
Signal 49 done
Signal 50 done
Signal 51 done
Signal 52 done
Signal 53 done
Signal 54 done
Signal 55 done
Signal 56 done
Signal 57 done
Signal 58 done
Signal 59 done
Signal 60 do

In [None]:
# Median error curves: sum the error over nodes (axis 3), then take the
# median over signals (axis 1), leaving one curve per decoder.
med_mse = np.median(err.sum(axis=3), axis=1)
med_mse_n = np.median(err_wrt_n.sum(axis=3), axis=1)

for curves, ttl in ((med_mse, 'Median MSE wrt original'),
                    (med_mse_n, 'Median MSE wrt noise')):
    plt.figure()
    plt.semilogy(curves.T)  # transpose: epochs on the x-axis, one line per decoder
    plt.grid(True, which='both')
    plt.legend(leg)
    plt.title(ttl)
    plt.tight_layout()

# Plot the fitting error projected onto selected eigenvectors of the squared
# Jacobian (columns `ind` of exps[k]['V']), for a small and a large decoder.
ind = np.array([0, 1, 2, 29, -2, -1])
# Number of leading epochs shown for the largest decoder (zoom on early training).
epochs_zoom = 199


def plot_median_proj_err(err_arr, V_sel, title, labels, n_epochs=None):
    """Plot |median over signals of the error projected onto V_sel| per epoch.

    Args:
        err_arr: one decoder's slice of ``err`` / ``err_wrt_n``, indexed as
            (signal, epoch, node).
        V_sel: matrix whose columns are the eigenvectors to project onto; its
            row count must match the node axis of ``err_arr``.
        title: figure title.
        labels: legend labels, one per column of ``V_sel``.
        n_epochs: if given, only the first ``n_epochs`` epochs are used;
            ``None`` uses all epochs.

    Returns:
        The median projection that was plotted, shape (epochs, n_columns).
    """
    proj = np.median(err_arr[:, :n_epochs, :].dot(V_sel), axis=0)
    plt.figure()
    plt.semilogy(np.abs(proj))
    plt.legend(labels)
    plt.grid(True, which='both')
    plt.title(title)
    plt.tight_layout()
    return proj


# Select each decoder by its legend label so the error slice, the
# eigenvectors and the title always refer to the same decoder.
# (The original 5000-feature "wrt original" plot mixed err[3] -- the
# 150-feature decoder -- with exps[-1]'s eigenvectors.)
for width_leg, n_epochs in (('64', None), ('5000', epochs_zoom)):
    k = leg.index(width_leg)
    V_ind = exps[k]['V'][:, ind]
    proj_err = plot_median_proj_err(
        err[k], V_ind,
        'Median Proj Err ({}fts) wrt original'.format(width_leg), ind, n_epochs)
    proj_err_n = plot_median_proj_err(
        err_wrt_n[k], V_ind,
        'Median Proj Err ({}fts) wrt noisy'.format(width_leg), ind, n_epochs)

In [None]:
# Render all matplotlib figures opened by the previous cells.
plt.show()