In [1]:
%reset -f
%load_ext autoreload
%autoreload 2
%matplotlib ipympl

In [2]:
import numpy as np
import numpy.random as npr
np.set_printoptions(linewidth=200)
from matplotlib import pyplot as plt
import scipy as sp
from scipy import stats

from time import time

from myutils import *

In [3]:
import tensorflow as tf
import tensorflow_probability as tfp
from ctrnn import CTRNNCell, PlotEachBatch

# tf.keras.backend.set_floatx('float32')

In [4]:
# Record the TensorFlow and TensorFlow Probability versions used for this run.
tf.__version__, tfp.__version__

('2.2.0', '0.10.1')

In [5]:
def _real_with_test(x, name='x', tol=1e-6):
    """Return the real part of ``x`` and report a non-negligible imaginary part.

    The norm of the flattened imaginary part is divided by the norm of the
    flattened ``x`` (the latter cast to float64). When that ratio exceeds
    ``tol`` a diagnostic line labelled with ``name`` is printed via
    ``tf.print``. The real part is returned in every case.
    """
    flat_imag = tf.reshape(tf.math.imag(x), -1)
    flat_full = tf.reshape(x, -1)
    imag_norm = tf.linalg.norm(flat_imag)
    tot_norm = tf.cast(tf.linalg.norm(flat_full), tf.float64)
    imag_is_significant = imag_norm / tot_norm > tol
    if imag_is_significant:
        tf.print(f'||{name}.imag||/||{name}|| = {imag_norm / tot_norm}')
    return tf.math.real(x)

from collections.abc import Iterable

def iterable(obj):
    """Tell whether ``obj`` is an instance of ``collections.abc.Iterable``."""
    is_iterable = isinstance(obj, Iterable)
    return is_iterable

In [56]:
# simulation parameters
dt = 0.1  # spacing of the time grid t = np.arange(n_seconds / dt) * dt used when evaluating trajectories
n_seconds = 10.0  # length of that time grid, in the same units as dt

# network parameters
tau = 1.0  # time constant: eigenvalues d enter the dynamics as (d - 1) / tau
N = 200  # size of the square matrix J0
n_frac = 0.1  # rank of the U @ V term as a fraction of N
n = int(round(N * n_frac))  # number of columns of U / rows of V

# Set the CT and TC weights to the same values as the J weights
# sigma_U = 1 / np.sqrt(N)
# sigma_V = 1 / np.sqrt(N)

# Set the CTC effective weights to the same values as the J weights
# (with independent zero-mean entries of std sigma_U and sigma_V, each entry of U @ V is a sum of
# n = n_frac * N products and so has std sqrt(n) * sigma_U * sigma_V = 1 / sqrt(N), the std of J0's entries)
sigma_U = 1 / (n_frac * N * N)**0.25
sigma_V = 1 / (n_frac * N * N)**0.25

initial_scale = 0.4  # extra factor applied to the initial random draws of U and V
fixed_norms = True  # if True, _get_U_and_V rescales every column of both parameter halves to norm sqrt(N) * sigma

beta = 0.05  # weight of the w-dependent term of loss_f
test_runs = 100  # number of random initial conditions drawn for the trajectory test
test_runs_to_plot = 3  # how many of those trajectories are plotted

# learning
learning_rate = 0.01  # Adam learning rate
epochs = 20001  # number of optimizer steps per training loop
print_every = 500  # progress-print interval (the final training loop prints every print_every // 4 steps)
split_first = False  # if True, first train U and V separately on the two diagonal half-blocks of J0
ext_penalty = 1e5  # multiplier of the loss_f penalty on eigenvalues with Re((d - 1) / tau) > 0
dist_min = 1e-2  # eigenvalue-separation margin used by distance_to_margin in loss_f (its penalty term is commented out)

# random seeds
# NOTE(review): neither name is read by any cell visible here; the initialisation cell calls npr.seed(0) — confirm intent
numpy_seed = tf_seed = 1

In [57]:
# Seed NumPy's global RNG so the draws below are repeatable
# NOTE(review): the literal 0 is used rather than numpy_seed from the parameter cell — confirm which is intended
npr.seed(0)
# The order of the four draws fixes which random numbers each array receives
J0 = npr.randn(N, N) / np.sqrt(N)  # N x N matrix, entries of std 1 / sqrt(N)
V = initial_scale * npr.randn(n, N) * sigma_V  # n x N
U = initial_scale * npr.randn(N, n) * sigma_U  # N x n
w = npr.randn(N) / np.sqrt(N)  # length-N vector used in the beta-weighted term of loss_f

# Index 0 of each list is the full problem
Js, Us, Vs, ws = [J0], [U], [V], [w]
if split_first:
    # Indices 1 and 2 hold the first-half and second-half sub-problems: the two diagonal
    # blocks of J0 with the matching rows of U, columns of V and entries of w
    Js = [J0, J0[:N // 2, :N // 2], J0[N // 2:, N // 2:]]
    Us = [U, U[:N // 2], U[N // 2:]]
    Vs = [V, V[:, :N // 2], V[:, N // 2:]]
    ws = [w, w[:N // 2], w[N // 2:]]

In [58]:
# One panel per (J, U, V) triple, showing the spectrum before training
fig, ax = plt.subplots(1, len(Js), figsize=[6.4 * len(Js), 4.8])
# With a single panel `ax` is not iterable; wrap it so the zip below works either way
ax = ax if iterable(ax) else np.array([ax])
theta = np.linspace(0, 2 * np.pi, 500)

for ax_, J_, U_, V_ in zip(ax, Js, Us, Vs):
    # Eigenvalues of J alone and of J + U @ V (the eigenvectors v, v_ are not used in this cell)
    d, v = np.linalg.eig(J_)
    d_, v_ = np.linalg.eig(J_ + U_ @ V_)
    ax_.plot(np.cos(theta), np.sin(theta), 'k')  # unit circle for reference
    ax_.plot(d.real, d.imag, '.')  # dots: J alone
    ax_.plot(d_.real, d_.imag, 'o', fillstyle='none')  # open circles: J + U @ V
    ax_.axis('equal')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [59]:
def _reciprocal(x, ep=1e-20):
    return x / (x * x + ep)

@tf.custom_gradient
def myeig(A):
    """Eigendecomposition ``tf.linalg.eig(A)`` with a hand-written gradient.

    Returns the pair ``(e, v)`` produced by ``tf.linalg.eig`` together with
    ``grad``, the function ``tf.custom_gradient`` uses for back-propagation.
    """
    e, v = tf.linalg.eig(A)
    def grad(grad_e, grad_v):
        # grad_e / grad_v: upstream gradients with respect to e and v.
        # f[..., i, j] = _reciprocal(e[j] - e[i]); _reciprocal returns 0 for a zero
        # difference, so coincident eigenvalues do not cause a division by zero
        f = _reciprocal(e[..., None, :] - e[..., None])
        # the diagonal (i == j) is zeroed explicitly, then f is conjugated
        f = tf.linalg.set_diag(f, tf.zeros_like(e))
        f = tf.math.conj(f)
        vt = tf.linalg.adjoint(v)  # conjugate transpose of the eigenvector matrix
        vgv = vt @ grad_v
        # diag(grad_e) carries the eigenvalue gradient; the f-weighted term carries the
        # eigenvector gradient, with column j of v scaled by vgv[j, j] before projecting with vt
        mid = tf.linalg.diag(grad_e) + f * (vgv - vt @ (v * tf.linalg.diag_part(vgv)[..., None, :]))
        # solves vt @ grad_a = mid @ vt for grad_a
        grad_a = tf.linalg.solve(vt, mid @ vt)
        # NOTE(review): when A is real this cast goes from complex to real, presumably
        # discarding the imaginary part — confirm that is acceptable
        return tf.cast(grad_a, A.dtype)
    return (e, v), grad

@tf.function
def loss_f(U, V, J0, w, tau, beta, eigf=myeig):
    """Scalar loss of the low-rank perturbation ``U @ V`` added to ``J0``.

    The loss is built from the eigendecomposition of ``J = J0 + U @ V`` and is
    the sum of
      * the real part of a quadratic form in the eigenvectors, divided by N,
      * ``beta`` times the real part of a quadratic form in ``w`` projected
        onto the eigenvectors, and
      * ``ext_penalty`` times the summed squares of the positive real parts
        of ``(d - 1) / tau``.

    ``ext_penalty`` and ``dist_min`` are read from the notebook's globals, not
    passed in.
    NOTE(review): because of ``@tf.function`` those globals are presumably
    fixed when the function is traced — confirm before changing them mid-run.
    NOTE(review): the explicit complex128 / float64 casts suggest float64
    inputs are expected — confirm.

    Args:
        U, V: factors of the perturbation; ``U @ V`` is added to ``J0``.
        J0: square matrix (its first dimension is used as N).
        w: vector, cast to complex128 and multiplied by the transposed
            eigenvector matrix.
        tau: divisor turning eigenvalues ``d`` into ``g = (d - 1) / tau``.
        beta: weight of the ``w`` term.
        eigf: callable returning ``(eigenvalues, eigenvectors)`` for a matrix.

    Returns:
        The loss as a real scalar tensor.
    """
    N = J0.shape[0]
    J = J0 + U @ V
    # d: eigenvalues, r: eigenvector matrix as returned by eigf
    d, r = eigf(J)
    # inverse of the Gram matrix r^H r of the eigenvectors
    L = tf.linalg.inv(tf.linalg.adjoint(r) @ r)
    g = (d - 1) / tau
    g_real = tf.math.real(g)
    # G1[i, j] = -_reciprocal(g[i] + conj(g[j])), a regularised -1 / (g[i] + conj(g[j]))
    G1 = -_reciprocal(g[:, None] + tf.math.conj(g))
    # G2[i, j] = g[i] * conj(g[j]) * G1[i, j]
    G2 = (g[:, None] * tf.math.conj(g)) * G1
    # r^T w (plain transpose, no conjugation)
    wR = tf.linalg.matvec(r, tf.cast(w, tf.complex128), transpose_a=True)
    # |d[i] - d[j]| - dist_min with a zeroed diagonal; only the commented-out penalty
    # at the end of the return statement uses it
    distance_to_margin = tf.linalg.set_diag(tf.abs(d[..., None] - d[..., None, :]) - dist_min, tf.zeros_like(d, dtype=tf.float64))
    # tf.where zeroes every g_real < 0, so the ext_penalty term only acts on eigenvalues with Re(g) > 0.
    # The trailing backslash is followed only by a comment line, so the statement ends there.
    return tf.math.real(tf.reduce_sum((r @ (L * G1)) * tf.math.conj(r)) / N + beta * tf.reduce_sum(wR * tf.linalg.matvec(L * G2, tf.math.conj(wR)))) \
           + ext_penalty * tf.reduce_sum(tf.where(g_real < 0, tf.zeros_like(g_real), g_real)**2) \
#            + ext_penalty * tf.reduce_sum(tf.where(distance_to_margin > 0, tf.zeros_like(distance_to_margin), distance_to_margin)**2) / 2

def _get_U_and_V(param):
    """Split a stacked parameter tensor into the matrices U and V used by the loss.

    ``param`` is cut into two equal halves along its last axis. The first half
    becomes U and always has its columns normalised (norm taken over axis -2).
    When the global ``fixed_norms`` is true, U is further scaled by
    ``sqrt(N) * sigma_U`` and the second half is normalised the same way and
    scaled by ``sqrt(N) * sigma_V``; otherwise the second half is left as is.
    The second half is returned transposed.

    Returns:
        ``(U, V)`` with V being the matrix transpose of the (possibly
        rescaled) second half.
    """
    U_raw, V_raw = tf.split(param, 2, axis=-1)
    U = U_raw / tf.linalg.norm(U_raw, axis=-2, keepdims=True)
    if not fixed_norms:
        return U, tf.linalg.matrix_transpose(V_raw)
    U = U * np.sqrt(N) * sigma_U
    V_unit = V_raw / tf.linalg.norm(V_raw, axis=-2, keepdims=True)
    V = V_unit * np.sqrt(N) * sigma_V
    return U, tf.linalg.matrix_transpose(V)

def forward_backward(param):
    """Evaluate the loss at ``param`` and its gradient with respect to ``param``.

    The loss is ``loss_f`` applied to the U and V obtained from ``param`` via
    ``_get_U_and_V``, using the notebook globals ``J0``, ``w``, ``tau`` and
    ``beta``. A call counter and a timestamp are kept as attributes on the
    function itself; a progress line is printed on the first call and on every
    50th call after that.

    Returns:
        ``(loss, gradient)``.
    """
    with tf.GradientTape() as tape:
        tape.watch(param)
        U, V = _get_U_and_V(param)
        loss = loss_f(U, V, J0, w, tau, beta, eigf=myeig)

    # Lazily create the bookkeeping attributes on the very first call
    if not hasattr(forward_backward, 'count'):
        forward_backward.count = 0
        forward_backward.time = time()

    # Periodic progress report; the timer restarts after each report
    if forward_backward.count % 50 == 0:
        print(f'iter {forward_backward.count:3d}, time {time() - forward_backward.time:.2f}, device {loss.device}')
        forward_backward.time = time()

    forward_backward.count += 1
    gradient = tape.gradient(loss, param)
    return loss, gradient

In [60]:
if split_first:
    # One trainable variable per half-problem: U and V.T placed side by side along the last
    # axis, which is the layout _get_U_and_V splits apart again
    var = [tf.Variable(np.concatenate((Us[i + 1], Vs[i + 1].T), axis=-1)) for i in range(2)]

    losses = []
    for J_, w_, var_ in zip(Js[1:], ws[1:], var):
        # Fresh optimizer for each half
        opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)

        losses_ = np.zeros(epochs)
        # Loss callable handed to opt.minimize. It is re-defined for every half, so the
        # count/time attributes it stores on itself start fresh each time.
        def stager():
            U, V = _get_U_and_V(var_)
            loss = loss_f(U, V, J_, w_, tau, beta, eigf=myeig)
            # Progress line on the first call and every print_every calls thereafter
            if not hasattr(stager, 'count') or stager.count % print_every == 0:
                if not hasattr(stager, 'count'):
                    stager.count = 0
                    stager.time = time()
                print(f'iter {stager.count:4d}, time {time() - stager.time:5.2f}, loss {loss:8.5f}, device {loss.device}')
                stager.time = time()
            # Record the loss of this call
            losses_[stager.count] = loss.numpy()
            stager.count += 1
            return loss

        # One opt.minimize call per epoch
        for _ in range(epochs):
            opt.minimize(stager, [var_])

        losses.append(losses_)

In [61]:
if split_first:
    # Loss histories of the two half-problems, side by side, on a common y range
    fig, axes = plt.subplots(1, 2, figsize=[6.4 * 2, 4.8 * 1])
    axes = axes.reshape(-1)

    axes[0].plot(losses[0])
    axes[0].set_ylim(0.1, 3)

    axes[1].plot(losses[1])
    axes[1].set_ylim(0.1, 3)

In [62]:
if split_first:
    # Recover the trained U and V of each half with _get_U_and_V, i.e. with exactly the
    # mapping the loss applied during training. Splitting and normalising by hand here
    # would only reproduce the fixed_norms=False rule, so with fixed_norms=True the
    # matrices evaluated afterwards would differ from the ones that were optimised.
    Us_new, Vs_new = [], []
    for var_ in var:
        # _get_U_and_V already returns V transposed, so no extra .T is needed
        U, V = [a.numpy() for a in _get_U_and_V(var_)]
        Us_new.append(U)
        Vs_new.append(V)
    # Index 0: the two halves stacked back into a full-size U (rows) and V (columns);
    # indices 1 and 2: the individual halves
    Us_new = [np.concatenate(Us_new, axis=-2)] + Us_new
    Vs_new = [np.concatenate(Vs_new, axis=-1)] + Vs_new
else:
    # No pre-training: start the full optimisation from the initial U and V
    Us_new, Vs_new = Us, Vs

In [63]:
if split_first:
    # For the full problem and each half, compare trajectory norms before (U0_, V0_) and
    # after (U_, V_) the split pre-training
    norms = []
    for J_, U0_, V0_, U_, V_ in zip(Js, Us, Vs, Us_new, Vs_new):
        d0_, v0_ = np.linalg.eig(J_ + U0_ @ V0_)
        d_, v_ = np.linalg.eig(J_ + U_ @ V_)
        # One random initial condition per column; the same c0 is used for both matrices
        c0 = npr.randn(d_.shape[0], test_runs)
        t = np.arange(n_seconds / dt) * dt

        # v @ diag(exp((d - 1) t / tau)) @ v^{-1} @ c0 for every time in t, i.e.
        # exp((M - I) t / tau) @ c0 evaluated through the eigendecomposition of M;
        # the result is indexed (run, component, time)
        tmp = v0_ @ (np.exp((d0_[:, None] - 1) * t / tau) * np.linalg.solve(v0_, c0).T[..., None])
        # a real matrix and real c0 should give a real trajectory up to round-off
        assert np.allclose(tmp.imag, 0)
        norms0_ = np.linalg.norm(tmp, axis=-2)  # norm over components -> (run, time)

        tmp = v_ @ (np.exp((d_[:, None] - 1) * t / tau) * np.linalg.solve(v_, c0).T[..., None])
        assert np.allclose(tmp.imag, 0)
        norms_ = np.linalg.norm(tmp, axis=-2)

        norms.append((norms0_, norms_))

In [64]:
if split_first:
    # 3 x 3 grid: one column per problem (full, first half, second half).
    # Rows 0 and 1 show the same spectra (row 1 zoomed to the unit disc); row 2 shows
    # trajectory norms. `t` and `norms` come from the preceding cell.
    fig, axes = plt.subplots(3, 3, figsize=[6.4 * 3, 4.8 * 3])

    theta = np.linspace(0, 2 * np.pi, 500)

    for i, axes_ in enumerate(axes[:-1]):
        for ax, J_, U0_, V0_, U_, V_ in zip(axes_, Js, Us, Vs, Us_new, Vs_new):
            d, v = np.linalg.eig(J_)
            d0_, v0_ = np.linalg.eig(J_ + U0_ @ V0_)
            d_, v_ = np.linalg.eig(J_ + U_ @ V_)

            ax.plot(np.cos(theta), np.sin(theta), 'k')  # unit circle for reference
            ax.plot(d.real, d.imag, '.')  # dots: J alone
            ax.plot(d0_.real, d0_.imag, 'x')  # crosses: J plus the initial U @ V
            ax.plot(d_.real, d_.imag, 'o', fillstyle='none')  # open circles: J plus the trained U @ V
            ax.axis('equal')
            if i == 1:
                ax.set(xlim=(-1.1, 1.1), ylim=(-1.1, 1.1))

    for ax, norms_ in zip(axes[-1], norms):
        # Plot the configured number of runs (test_runs_to_plot) rather than a hard-coded 3,
        # matching the main evaluation figure; blue: initial U, V, orange: trained U, V
        ax.plot(t, norms_[0][:test_runs_to_plot].T, color='tab:blue')
        ax.plot(t, norms_[1][:test_runs_to_plot].T, color='tab:orange')
    # ax.plot(t, norms.mean(axis=0), linewidth=2)

In [65]:
# Trainable variable for the full problem: U and V.T side by side along the last axis
# (the layout _get_U_and_V expects), started from Us_new[0] / Vs_new[0]
var = tf.Variable(np.concatenate((Us_new[0], Vs_new[0].T), axis=-1))

opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)
losses = np.zeros(epochs)
# Single-element lists so stager can update them without `global`
loss_best = [np.inf]

# Placeholder; replaced by a NumPy copy of `var` whenever a new best loss is seen
var_best = [tf.zeros_like(var)]

# Loss callable handed to opt.minimize; keeps its own call counter and timer as attributes
def stager():
    U, V = _get_U_and_V(var)
    loss = loss_f(U, V, Js[0], ws[0], tau, beta, eigf=myeig)
    # Progress line on the first call and every print_every // 4 calls thereafter
    if not hasattr(stager, 'count') or stager.count % (print_every // 4) == 0:
        if not hasattr(stager, 'count'):
            stager.count = 0
            stager.time = time()
        print(f'iter {stager.count:4d}, time {time() - stager.time:5.2f}, loss {loss:8.5f}, device {loss.device}')
        stager.time = time()
    losses[stager.count] = loss.numpy()
    # Track the lowest loss seen so far together with the parameters it was evaluated at
    if losses[stager.count] < loss_best[0]:
        loss_best[0] = losses[stager.count]
        var_best[0] = var.numpy()

    stager.count += 1
    return loss

# One opt.minimize call per epoch
for _ in range(epochs):
    opt.minimize(stager, [var])

iter    0, time  0.00, loss 40992.25297, device /job:localhost/replica:0/task:0/device:CPU:0
iter  125, time 75.08, loss  3.24166, device /job:localhost/replica:0/task:0/device:CPU:0
iter  250, time 75.01, loss  2.52192, device /job:localhost/replica:0/task:0/device:CPU:0
iter  375, time 74.10, loss  2.21016, device /job:localhost/replica:0/task:0/device:CPU:0
iter  500, time 70.79, loss  2.02419, device /job:localhost/replica:0/task:0/device:CPU:0
iter  625, time 70.97, loss  1.89637, device /job:localhost/replica:0/task:0/device:CPU:0
iter  750, time 70.30, loss  1.80108, device /job:localhost/replica:0/task:0/device:CPU:0
iter  875, time 69.79, loss  1.72616, device /job:localhost/replica:0/task:0/device:CPU:0
iter 1000, time 69.57, loss  1.66496, device /job:localhost/replica:0/task:0/device:CPU:0
iter 1125, time 71.00, loss  1.61353, device /job:localhost/replica:0/task:0/device:CPU:0
iter 1250, time 69.56, loss  1.56933, device /job:localhost/replica:0/task:0/device:CPU:0
iter 13

In [72]:
# Loss history of the full training run on a log-scaled y axis restricted to [0.4, 2]
plt.figure()
plt.plot(losses)
plt.ylim(0.4, 2)
plt.yscale('log')
# Displayed value: the best recorded loss and the step index (or indices) at which it occurs
loss_best[0], np.nonzero(losses == loss_best[0])

  """Entry point for launching an IPython kernel.


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

(0.6050649143810473, (array([20000]),))

In [68]:
# NOTE(review): this bare expression is not the last statement of the cell, so it is
# evaluated but presumably never displayed — confirm whether it can be dropped
var_best[0], var
# Loss at the final parameters (first) versus loss at the best recorded parameters (second)
loss_f(*_get_U_and_V(var), Js[0], ws[0], tau, beta, eigf=myeig), loss_f(*_get_U_and_V(var_best[0]), Js[0], ws[0], tau, beta, eigf=myeig),

(<tf.Tensor: shape=(), dtype=float64, numpy=0.6050641290637333>,
 <tf.Tensor: shape=(), dtype=float64, numpy=0.6050649143810473>)

In [69]:
# U and V at the end of training (taken from `var`, not from var_best)
U_end, V_end = _get_U_and_V(var)

# Eigendecompositions with the initial (d0_, v0_) and the trained (d_, v_) low-rank term
d0_, v0_ = np.linalg.eig(Js[0] + Us[0] @ Vs[0])
d_, v_ = np.linalg.eig(Js[0] + U_end @ V_end)
# One random initial condition per column; the same c0 is used for both matrices
c0 = npr.randn(d_.shape[0], test_runs)
t = np.arange(n_seconds / dt) * dt

# v @ diag(exp((d - 1) t / tau)) @ v^{-1} @ c0 for every time in t, i.e.
# exp((M - I) t / tau) @ c0 evaluated through the eigendecomposition of M;
# the result is indexed (run, component, time)
tmp = v0_ @ (np.exp((d0_[:, None] - 1) * t / tau) * np.linalg.solve(v0_, c0).T[..., None])
# a real matrix and real c0 should give a real trajectory up to round-off
assert np.allclose(tmp.imag, 0)
norms0_ = np.linalg.norm(tmp, axis=-2)  # norm over components -> (run, time)

tmp = v_ @ (np.exp((d_[:, None] - 1) * t / tau) * np.linalg.solve(v_, c0).T[..., None])
assert np.allclose(tmp.imag, 0)
norms_ = np.linalg.norm(tmp, axis=-2)

fig, axes = plt.subplots(2, 2, figsize=[6.4 * 2, 4.8 * 2])

# Top left: loss history
axes[0, 0].plot(losses)
# plt.yscale('log')
axes[0, 0].set_ylim(.3, 2)

# Top right: trajectory norms for the first test_runs_to_plot runs
# (blue: initial U, V; orange: trained U, V)
axes[0, 1].plot(t, norms0_[:test_runs_to_plot].T, color='tab:blue')
axes[0, 1].plot(t, norms_[:test_runs_to_plot].T, color='tab:orange')
axes[0, 1].set_ylim(-1.2, 13.2)
# ax.plot(t, norms.mean(axis=0), linewidth=2)

# Bottom row: the same spectrum twice (dots: J alone, open circles: J plus the trained U @ V);
# the right-hand panel is zoomed to the unit disc
theta = np.linspace(0, 2 * np.pi, 500)
d, v = np.linalg.eig(Js[0])
for i, ax in enumerate(axes[1]):
    ax.plot(np.cos(theta), np.sin(theta), 'k')
    ax.plot(d.real, d.imag, '.')
#     ax.plot(d0_.real, d0_.imag, 'x')
    ax.plot(d_.real, d_.imag, 'o', fillstyle='none')
    ax.axis('equal')
    if i == 1:
        ax.set(xlim=(-1.1, 1.1), ylim=(-1.1, 1.1))



Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [73]:
# NOTE(review): seaborn is imported here rather than with the other imports at the top of the notebook
import seaborn as sns
plt.figure()
plt.ylim(0, 10)

# Kernel density estimates of three sets of entries, each labelled with its standard deviation:
# the entries of J
sns.distplot(Js[0].flatten(), hist = False, kde = True,
                 kde_kws = {'linewidth': 2, 'shade': True},
                 label = f'J (std = {Js[0].flatten().std():5.3f})')
# the entries of the trained U and V (V transposed) pooled together
sns.distplot(np.concatenate((U_end, tf.transpose(V_end)), -1).flatten(), hist = False, kde = True,
                 kde_kws = {'linewidth': 2, 'shade': True},
                 label = f'v & u (std = {np.concatenate((U_end, tf.transpose(V_end)), -1).flatten().std():5.3f})')
# the entries of the product U_end @ V_end (evaluated on a finer KDE grid)
sns.distplot((U_end @ V_end).numpy().flatten(), hist = False, kde = True,
                 kde_kws = {'linewidth': 2, 'shade': True, 'gridsize': 10000},
                 label = f"v'u (std = {(U_end @ V_end).numpy().flatten().std():5.3f})")
plt.xlim(-0.5, 0.5)

  


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

(-0.5, 0.5)

In [161]:
# NOTE(review): apart from the seaborn import, this cell repeats the preceding density-plot
# cell statement for statement, and it relies on `sns` having been imported there
plt.figure()
plt.ylim(0, 10)

# Kernel density estimates of three sets of entries, each labelled with its standard deviation:
# the entries of J
sns.distplot(Js[0].flatten(), hist = False, kde = True,
                 kde_kws = {'linewidth': 2, 'shade': True},
                 label = f'J (std = {Js[0].flatten().std():5.3f})')
# the entries of the trained U and V (V transposed) pooled together
sns.distplot(np.concatenate((U_end, tf.transpose(V_end)), -1).flatten(), hist = False, kde = True,
                 kde_kws = {'linewidth': 2, 'shade': True},
                 label = f'v & u (std = {np.concatenate((U_end, tf.transpose(V_end)), -1).flatten().std():5.3f})')
# the entries of the product U_end @ V_end (evaluated on a finer KDE grid)
sns.distplot((U_end @ V_end).numpy().flatten(), hist = False, kde = True,
                 kde_kws = {'linewidth': 2, 'shade': True, 'gridsize': 10000},
                 label = f"v'u (std = {(U_end @ V_end).numpy().flatten().std():5.3f})")
plt.xlim(-0.5, 0.5)

  """Entry point for launching an IPython kernel.


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

(-0.5, 0.5)

In [22]:
# Scratch calculation: square root of 0.002
# NOTE(review): the origin of 0.002 is not shown in the visible cells — confirm or remove
np.sqrt(0.002)

0.044721359549995794

In [25]:
# The two candidate scales for sigma_U / sigma_V from the parameter cell:
# the one in use (first) and the commented-out 1 / sqrt(N) alternative (second)
1 / (n_frac * N * N)**0.25, 1 / N**0.5

(0.14953487812212204, 0.1)

In [None]:
# NOTE(review): none of these names is read by any visible cell; they look like
# per-section counts for a manuscript — confirm, or move them out of this notebook
title = 12
abstract = 156
figure_captions = 1373
references = 874
intro_and_results = 3565