In [1]:
import matplotlib
import numpy as np
import rescomp as rc
import networkx as nx
from scipy import sparse
import matplotlib.pyplot as plt
import sklearn.gaussian_process as GP
from scipy.interpolate import CubicSpline

import warnings
warnings.filterwarnings('ignore')

In [2]:
seed=52673
np.random.seed(seed)

In [3]:
# plt params
plt.rcParams.update({
    'text.usetex': True,
    "font.family": "serif",
    'font.sans-serif': ['Computer Modern Roman'],
    'font.serif': ['Computer Modern Roman'],
    'text.latex.preamble': r'\usepackage{amsfonts}',
})

# arrow params
arrow_params = dict(
    arrowprops={
        'width': 2,
        'headwidth': 10,
        'headlength': 10,
        'color': 'k',
    },
    xycoords='figure fraction', 
    textcoords='figure fraction',
)

# text params
labelparams = {
        'fontsize': 40,
        'horizontalalignment': 'center',
        #'transform': ax_overlay.transAxes,
    }

tickparams = {
        'fontsize': 30,
        'horizontalalignment': 'right',
        #'transform': ax_overlay.transAxes,
    }


In [4]:
def valid_prediction_index(err, tol):
    """First index i where err[i] > tol. err is assumed to be 1D and tol is a float. If err is never greater than tol, then len(err) is returned."""
    mask = np.logical_or(err > tol, ~np.isfinite(err))
    if np.any(mask):
        return np.argmax(mask)
    return len(err)

def wa_vptime(ts, Uts, pre, vpttol=0.5):
    """
    Valid prediction time for a specific instance.
    """
    err = np.linalg.norm(Uts-pre, axis=1, ord=2)
    idx = valid_prediction_index(err, vpttol)
    if idx == 0:
        vptime = 0.
    else:
        vptime = ts[idx-1] - ts[0]

    return vptime

In [5]:
# GET TRAINING AND TESTING SIGNALS
t, U = rc.orbit('lorenz', duration=1000)
u = CubicSpline(t, U)
tr = t[:4000]
Utr = u(t[:4000])
ts = t[4000:]
Uts = u(t[4000:])

In [6]:
def testit(rho, pthin, a, g, s, norm=True):

    seed=52673
    np.random.seed(seed)

    # OTHER PARAMS
    n = 50
    c = 4

    # CREATE GRAPH
    A = nx.erdos_renyi_graph(n,c*(1-pthin)/(n-1),directed=True)
    A = sparse.dok_matrix(nx.adj_matrix(A).T)

    # MAKE RESERVOIR
    B = A*(rho/np.abs(sparse.linalg.eigs(A.astype(float),k=1)[0][0]))
    res = rc.ResComp(B.tocoo(), spect_rad=rho, sigma=s, gamma=g, ridge_alpha=a)
    res.train(tr, Utr)
                
    # GET VPT
    Upred, _ = res.predict(ts, r0=res.r0, return_states=True)

    #  RES_CONDITION
    r0 = res.initial_condition(Utr[0])
    res_states = res.internal_state_response(tr, Utr, r0=r0)

    # VPT
    vpt = wa_vptime(ts[:-2], Uts[:-2], Upred[2:], vpttol=5)

    return Upred, vpt, res_states, res

In [7]:
def runit(rho, pthin, a, g, s, norm=True):
    # OTHER PARAMS
    n = 500
    c = 4

    # CREATE GRAPH
    A = nx.erdos_renyi_graph(n,c*(1-pthin)/(n-1),directed=True)
    A = sparse.dok_matrix(nx.adj_matrix(A).T)

    # MAKE RESERVOIR
    if norm:
        B = A*(rho/np.max(np.abs(sparse.linalg.eigs(A.astype(float))[0])))
        res = rc.ResComp(B.tocoo(), spect_rad=rho, sigma=s, gamma=g, ridge_alpha=a)
    else:
        res = rc.ResComp(A.tocoo().astype(float), spect_rad=rho, sigma=s, gamma=g, ridge_alpha=a)
    res.train(tr, Utr)
                
    # GET VPT
    Uhat, tr_states = res.predict(ts, r0=res.r0, return_states=True)
    Upred, pred_states = res.predict(ts, r0=res.r0, return_states=True)
    vpt = wa_vptime(ts[:-2], Uts[:-2], Upred[2:], vpttol=5)

    return Upred, vpt, pred_states, Uhat, tr_states

In [8]:
# Generate attractor orbit
t, U = rc.orbit('lorenz', trim=True, duration=30)
tr_t = t[900:1100]
tr_U = U[900:1100]

# Train res
res_sz = 50
rescomp = rc.ResComp(res_sz=res_sz, mean_degree=0.1, map_initial='activ_f', ridge_alpha=0.0001, gamma=5, sigma=0.05)
rescomp.train(tr_t, tr_U)

# Set t, U
r0 = rescomp.initial_condition(tr_U[0])
res_signal = rescomp.internal_state_response(tr_t, tr_U, r0)

# Reconstructed signal
U_hat = res_signal @ rescomp.W_out.T

# Predicted signal
pred_t = np.linspace(51,55.99,500)
pred_res, pred_states = rescomp.predict(np.linspace(tr_t[-1], tr_t[-1]+5, 500), r0=rescomp.r0, return_states=True)

In [9]:
# Figure
fig = plt.figure(figsize=(12,5.5))



# Set up input signal
ax_sig = fig.add_axes([0.0, 0.1, 0.4, 0.7])
ax_sig.set_xticks([])
ax_sig.set_yticks([0])
# ax_sig.set_ylim(-100,105)

# Set up responses
ax_res = fig.add_axes([0.7, 0.1, 0.4, 0.7])
ax_res.set_xticks([])
ax_res.set_yticks([-1,0,1])

# Set up overlay w/text params
ax_overlay = fig.add_axes([0, 0, 1, 1])
ax_overlay.axis([0,1,0,1])
ax_overlay.axis('off')





# Plot input signal
ax_sig.plot(tr_t, tr_U, color=(0.0, 0.3, 0.8))
ax_sig.scatter([tr_t[0]]*3, tr_U[0], color=(0.0, 0.3, 0.8))
ax_sig.scatter([tr_t[-1]]*3, tr_U[-1], color=(0.0, 0.3, 0.8))

ax_sig.axvline(
            tr_t[0], 
            ymin=0, ymax=1,
            color=(0.0, 0.3, 0.8), linestyle='--', alpha=0.8, 
        )

ax_sig.axvline(
            tr_t[-1], 
            ymin=0, ymax=1,
            color=(0.0, 0.3, 0.8), linestyle='--', alpha=0.8, 
        )



ax_sig.spines['top'].set_visible(False)
ax_sig.spines['right'].set_visible(False)
ax_sig.spines['bottom'].set_position('zero')
    
# Plot reservoir states
order = np.argsort(res_signal[0])
for i,c in enumerate(plt.get_cmap('plasma')(np.linspace(0,0.8,res_sz))):
    ax_res.plot(tr_t, res_signal[:, order[i]], '-', color=c, alpha=0.8)
ax_res.plot([tr_t[0]]*res_sz, res_signal[0], 'k.')
ax_res.plot([tr_t[-1]]*res_sz, res_signal[-1], 'k.')

ax_res.spines['top'].set_visible(False)
ax_res.spines['right'].set_visible(False)
ax_res.spines['bottom'].set_position('zero')

    
    
# Plot arrows
ax_overlay.annotate('', xy=(0.68, 0.5), xytext=(0.43, 0.5),
    **arrow_params
)
    
    
    
# Plot labels
ax_overlay.text(0.22, 0.02, r'$\mathbf{u}(t) \in \mathbb{R}^{m}$', **labelparams)
ax_overlay.text(0.57, 0.53, r'$A \in \mathbb{R}^{n \times n}$', **labelparams)
ax_overlay.text(0.92, 0.02, r'$\mathbf{r}(t) \in \mathbb{R}^{n}$', **labelparams)

ax_overlay.text(0.09, 0.87, r'$\mathbf{u}(0)$', **tickparams)
ax_overlay.text(0.44, 0.87, r'$\mathbf{u}(T)$', **tickparams)
ax_overlay.text(0.86, 0.87, r'$\mathbf{r}(0) = \mathbf{r}_0$', **tickparams)
ax_overlay.text(1.13, 0.87, r'$\mathbf{r}(T)$', **tickparams)
    
    
# Show plot
# plt.savefig('wa_fig1.png')
plt.show()

RuntimeError: Failed to process string with tex because latex could not be found

<Figure size 1200x550 with 3 Axes>