In [1]:
# Clear the interactive namespace without prompting (-f) before anything else runs.
%reset -f
# Load the autoreload extension and switch it to mode 2.
%load_ext autoreload
%autoreload 2
# Use the ipympl Matplotlib backend for the figures in this notebook.
%matplotlib ipympl

In [2]:
import numpy as np
import numpy.random as npr
np.set_printoptions(linewidth=200)
from matplotlib import pyplot as plt
import scipy as sp
from scipy import stats

from time import time

from myutils import *

In [3]:
import tensorflow as tf
import tensorflow_probability as tfp
from ctrnn import CTRNNCell, PlotEachBatch

# tf.keras.backend.set_floatx('float32')

In [4]:
# Display the TensorFlow and TensorFlow Probability versions in use.
tf.__version__, tfp.__version__

('2.2.0', '0.10.1')

In [5]:
def _real_with_test(x, name='x', tol=1e-6):
    """Return the real part of ``x``, reporting when the dropped imaginary part is large.

    Args:
        x: Tensor whose real part is wanted.
        name: Label used for ``x`` in the diagnostic message.
        tol: Relative threshold; a message is printed when
            ``||imag(x)|| / ||x||`` exceeds it.

    Returns:
        ``tf.math.real(x)``.
    """
    # Both norms are cast to float64 before dividing. The imaginary part of a
    # complex64 tensor is float32, and dividing it by the float64 total norm
    # would fail with a dtype mismatch.
    imag_norm = tf.cast(tf.linalg.norm(tf.reshape(tf.math.imag(x), -1)), tf.float64)
    tot_norm = tf.cast(tf.linalg.norm(tf.reshape(x, -1)), tf.float64)
    ratio = imag_norm / tot_norm
    if ratio > tol:
        tf.print(f'||{name}.imag||/||{name}|| = {ratio}')
    return tf.math.real(x)

In [19]:
# simulation parameters
# dt is the spacing and n_seconds the total span of the time grid
# t = np.arange(n_seconds / dt) * dt built in the analysis cells.
dt = 0.1
n_seconds = 10.0

# network parameters
# tau divides the shifted eigenvalues (d - 1) in both the loss and the trajectories.
tau = 1.0
# N is the side length of the random matrix J0.
N = 200
# n = round(N * n_frac) is the inner dimension of the factors U (N x n) and V (n x N).
n_frac = 0.2
# beta weights the w-dependent term of loss_f.
beta = 0.05
# Number of random initial vectors drawn when evaluating trajectories.
test_runs = 100
# NOTE(review): not referenced in the cells visible here; the plotting cells slice [:3] directly.
test_runs_to_plot = 3

# learning
# Adam step size, number of optimiser steps, and progress-print interval.
learning_rate = 0.05
epochs = 4001
print_every = 200

# random seeds
# NOTE(review): neither name is referenced in the cells visible here; the setup cell calls npr.seed(0).
numpy_seed = tf_seed = 1

In [20]:
# Draw the random base matrix, the initial factors and the vector w.
# The order of the randn calls fixes the random stream, so it must not change.
npr.seed(0)
n = int(round(N * n_frac))
sqrt_N = np.sqrt(N)
J0 = npr.randn(N, N) / sqrt_N
V = npr.randn(n, N) / sqrt_N
U = npr.randn(N, n) / np.sqrt(n)
w = npr.randn(N) / sqrt_N

# Entry 0 of each list is the full system; entries 1 and 2 are the
# first-half and second-half diagonal sub-blocks.
half = N // 2
first, second = slice(None, half), slice(half, None)
Js = [J0, J0[first, first], J0[second, second]]
Us = [U, U[first], U[second]]
Vs = [V, V[:, first], V[:, second]]
ws = [w, w[first], w[second]]

In [8]:
# One panel per system: eigenvalues of J alone (dots) and of J + U V (open
# circles), drawn over the unit circle.
fig, ax = plt.subplots(1, 3, figsize=[6.4 * 3, 4.8])
theta = np.linspace(0, 2 * np.pi, 500)
unit_circle = (np.cos(theta), np.sin(theta))

for ax_, J_, U_, V_ in zip(ax, Js, Us, Vs):
    d, v = np.linalg.eig(J_)
    d_, v_ = np.linalg.eig(J_ + U_ @ V_)
    ax_.plot(*unit_circle, 'k')
    for spectrum, marker, extra in ((d, '.', {}), (d_, 'o', {'fillstyle': 'none'})):
        ax_.plot(spectrum.real, spectrum.imag, marker, **extra)
    ax_.axis('equal')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [9]:
def _reciprocal(x, ep=1e-20):
    return x / (x * x + ep)

@tf.custom_gradient
def myeig(A):
    """``tf.linalg.eig(A)`` with a hand-written backward pass.

    Returns:
        The pair ``(e, v)`` produced by ``tf.linalg.eig`` together with a
        function mapping the upstream gradients ``(grad_e, grad_v)`` to the
        gradient with respect to ``A``, cast to ``A.dtype``.
    """
    eigvals, eigvecs = tf.linalg.eig(A)

    def grad(grad_e, grad_v):
        # gaps[..., i, j] = e_j - e_i. The regularised reciprocal keeps
        # near-equal eigenvalues finite, and the diagonal is forced to zero.
        gaps = eigvals[..., None, :] - eigvals[..., None]
        pairwise = tf.linalg.set_diag(_reciprocal(gaps), tf.zeros_like(eigvals))
        pairwise = tf.math.conj(pairwise)

        vecs_h = tf.linalg.adjoint(eigvecs)
        projected = vecs_h @ grad_v
        # Scale column j of the eigenvector matrix by projected[j, j].
        scaled_vecs = eigvecs * tf.linalg.diag_part(projected)[..., None, :]
        inner = tf.linalg.diag(grad_e) + pairwise * (projected - vecs_h @ scaled_vecs)
        # Solve vecs_h @ X = inner @ vecs_h instead of forming an inverse.
        grad_a = tf.linalg.solve(vecs_h, inner @ vecs_h)
        return tf.cast(grad_a, A.dtype)

    return (eigvals, eigvecs), grad

@tf.function
def loss_f(U, V, J0, w, tau, beta, eigf=myeig):
    """Scalar loss built from the eigendecomposition of ``J0 + U @ V``.

    Args:
        U, V: Factors whose product is added to ``J0``.
        J0: Base matrix; its first dimension is used as the normaliser.
        w: Vector, cast to complex128 and multiplied by the transposed
            eigenvector matrix.
        tau: Divides the shifted eigenvalues ``d - 1``.
        beta: Weight of the ``w``-dependent term.
        eigf: Callable returning ``(eigenvalues, eigenvectors)`` for a matrix.

    Returns:
        The real part of the sum of the two terms computed below.
    """
    size = J0.shape[0]
    eigvals, r = eigf(J0 + U @ V)
    gram_inv = tf.linalg.inv(tf.linalg.adjoint(r) @ r)

    rates = (eigvals - 1) / tau
    rates_conj = tf.math.conj(rates)
    G1 = -1 / (rates[:, None] + rates_conj)
    G2 = (rates[:, None] * rates_conj) * G1

    wR = tf.linalg.matvec(r, tf.cast(w, tf.complex128), transpose_a=True)
    first_term = tf.reduce_sum((r @ (gram_inv * G1)) * tf.math.conj(r)) / size
    second_term = beta * tf.reduce_sum(wR * tf.linalg.matvec(gram_inv * G2, tf.math.conj(wR)))
    return tf.math.real(first_term + second_term)

def forward_backward(param):
    """Evaluate the loss for a packed parameter tensor and its gradient.

    ``param`` is split in two along its last axis. The first half is divided
    by its norm along axis -2 to give ``U``; the second half is transposed to
    give ``V``. The notebook-level ``J0``, ``w``, ``tau`` and ``beta`` are used.

    A call counter and a timestamp are stored as attributes on the function;
    every 50th call (starting with the first) prints the counter, the time
    since the previous report and the device of the loss tensor.

    Returns:
        ``(loss, gradient of loss with respect to param)``.
    """
    with tf.GradientTape() as tape:
        tape.watch(param)
        U_raw, V_raw = tf.split(param, 2, axis=-1)
        U_unit = U_raw / tf.linalg.norm(U_raw, axis=-2, keepdims=True)
        V_t = tf.linalg.matrix_transpose(V_raw)
        loss = loss_f(U_unit, V_t, J0, w, tau, beta, eigf=myeig)

    # First call: create the bookkeeping attributes.
    if not hasattr(forward_backward, 'count'):
        forward_backward.count = 0
        forward_backward.time = time()
    if forward_backward.count % 50 == 0:
        print(f'iter {forward_backward.count:3d}, time {time() - forward_backward.time:.2f}, device {loss.device}')
        forward_backward.time = time()
    forward_backward.count += 1

    gradient = tape.gradient(loss, param)
    return loss, gradient

In [11]:
# One packed variable [U | V^T] for each of the two diagonal sub-blocks.
var = [tf.Variable(np.concatenate((U_init, V_init.T), axis=-1))
       for U_init, V_init in zip(Us[1:3], Vs[1:3])]

losses = []
for J_, w_, var_ in zip(Js[1:], ws[1:], var):
    opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    losses_ = np.zeros(epochs)

    def stager():
        """Loss callable for ``opt.minimize``; also records and reports progress."""
        U_raw, V_raw = tf.split(var_, 2, axis=-1)
        U_unit = U_raw / tf.linalg.norm(U_raw, axis=-2, keepdims=True)
        V_t = tf.linalg.matrix_transpose(V_raw)
        loss = loss_f(U_unit, V_t, J_, w_, tau, beta, eigf=myeig)

        # First call: create the counter and timestamp attributes.
        if not hasattr(stager, 'count'):
            stager.count = 0
            stager.time = time()
        if stager.count % print_every == 0:
            print(f'iter {stager.count:4d}, time {time() - stager.time:5.2f}, loss {loss:8.5f}, device {loss.device}')
            stager.time = time()

        losses_[stager.count] = loss.numpy()
        stager.count += 1
        return loss

    for _ in range(epochs):
        opt.minimize(stager, [var_])

    losses.append(losses_)

iter    0, time  0.00, loss  0.85639, device /job:localhost/replica:0/task:0/device:CPU:0
iter  200, time 12.64, loss  0.27797, device /job:localhost/replica:0/task:0/device:CPU:0
iter  400, time 12.37, loss  0.26842, device /job:localhost/replica:0/task:0/device:CPU:0
iter  600, time 11.60, loss  0.26416, device /job:localhost/replica:0/task:0/device:CPU:0
iter  800, time 11.65, loss  0.26159, device /job:localhost/replica:0/task:0/device:CPU:0
iter 1000, time 11.40, loss  0.25986, device /job:localhost/replica:0/task:0/device:CPU:0
iter 1200, time 11.37, loss  0.25849, device /job:localhost/replica:0/task:0/device:CPU:0
iter 1400, time 11.20, loss  0.25745, device /job:localhost/replica:0/task:0/device:CPU:0
iter 1600, time 11.18, loss  0.25660, device /job:localhost/replica:0/task:0/device:CPU:0
iter 1800, time 11.47, loss  0.25596, device /job:localhost/replica:0/task:0/device:CPU:0
iter 2000, time 11.86, loss  0.25528, device /job:localhost/replica:0/task:0/device:CPU:0
iter 2200,

In [13]:
# Loss recorded at every optimiser step, one panel per sub-block run.
fig, axes = plt.subplots(1, 2, figsize=[6.4 * 2, 4.8 * 1])
axes = axes.reshape(-1)

# A loop replaces the copy-pasted per-axis code. Ending the cell on a loop
# also avoids echoing the return value of the last set_ylim call.
for panel, curve in zip(axes, losses):
    panel.plot(curve)
    panel.set_ylim(0.1, 1)
    panel.set(xlabel='iteration', ylabel='loss')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

(0.1, 1.0)

In [14]:
# Unpack each trained variable into (column-normalised U, V), applying the
# same normalisation and transpose as during training.
# The temporaries are deliberately not called U / V, so the notebook-level
# initial factors defined in the setup cell are not overwritten.
trained_Us, trained_Vs = [], []
for var_ in var:
    U_raw, V_raw_t = (part.numpy() for part in tf.split(var_, 2, axis=-1))
    trained_Us.append(U_raw / np.linalg.norm(U_raw, axis=-2, keepdims=True))
    trained_Vs.append(V_raw_t.T)

# Entry 0 is the full-size pair assembled from the sub-block solutions
# (U stacked along rows, V along columns), followed by the sub-block pairs.
Us_new = [np.concatenate(trained_Us, axis=-2)] + trained_Us
Vs_new = [np.concatenate(trained_Vs, axis=-1)] + trained_Vs

In [15]:
def _trajectory_norms(eigvals, eigvecs, c0, t, tau):
    """Norms of trajectories propagated through an eigendecomposition.

    Computes ``eigvecs @ (exp((eigvals - 1) * t / tau) * (eigvecs^-1 @ c0))``
    for every column of ``c0`` and every time in ``t``.

    Args:
        eigvals: 1-D array of eigenvalues, length ``N``.
        eigvecs: ``(N, N)`` matrix of eigenvectors.
        c0: ``(N, runs)`` array; each column is one initial vector.
        t: 1-D array of times.
        tau: Scalar dividing the exponent.

    Returns:
        ``(runs, len(t))`` array of norms taken over the ``N`` components.

    Raises:
        AssertionError: If the propagated states have a non-negligible
            imaginary part.
    """
    coeffs = np.linalg.solve(eigvecs, c0).T[..., None]      # (runs, N, 1)
    growth = np.exp((eigvals[:, None] - 1) * t / tau)       # (N, T)
    states = eigvecs @ (growth * coeffs)                    # (runs, N, T)
    assert np.allclose(states.imag, 0)
    return np.linalg.norm(states, axis=-2)


# The time grid does not depend on the system, so build it once. The plotting
# cell reads it as `t`.
t = np.arange(n_seconds / dt) * dt

norms = []
for J_, U0_, V0_, U_, V_ in zip(Js, Us, Vs, Us_new, Vs_new):
    d0_, v0_ = np.linalg.eig(J_ + U0_ @ V0_)
    d_, v_ = np.linalg.eig(J_ + U_ @ V_)
    # The same random initial vectors are used for the initial and trained systems.
    c0 = npr.randn(d_.shape[0], test_runs)

    norms0_ = _trajectory_norms(d0_, v0_, c0, t, tau)
    norms_ = _trajectory_norms(d_, v_, c0, t, tau)
    norms.append((norms0_, norms_))

In [16]:
# Rows 0 and 1: eigenvalue spectra for each system (row 1 is the same data
# with fixed axis limits). Row 2: trajectory norms before and after training.
fig, axes = plt.subplots(3, 3, figsize=[6.4 * 3, 4.8 * 3])

theta = np.linspace(0, 2 * np.pi, 500)

# Both spectrum rows show the same eigenvalues, so compute them once per
# system. Only eigenvalues are plotted, so eigenvectors are not requested.
spectra = [
    (np.linalg.eigvals(J_),
     np.linalg.eigvals(J_ + U0_ @ V0_),
     np.linalg.eigvals(J_ + U_ @ V_))
    for J_, U0_, V0_, U_, V_ in zip(Js, Us, Vs, Us_new, Vs_new)
]

for i, axes_ in enumerate(axes[:-1]):
    for ax, (d, d0_, d_) in zip(axes_, spectra):
        ax.plot(np.cos(theta), np.sin(theta), 'k')
        ax.plot(d.real, d.imag, '.')
        ax.plot(d0_.real, d0_.imag, 'x')
        ax.plot(d_.real, d_.imag, 'o', fillstyle='none')
        ax.axis('equal')
        if i == 1:
            ax.set(xlim=(-1.1, 1.1), ylim=(-1.1, 1.1))

# The number of example runs drawn comes from the configuration cell rather
# than a hard-coded 3.
for ax, (norms0_, norms_) in zip(axes[-1], norms):
    ax.plot(t, norms0_[:test_runs_to_plot].T, color='tab:blue')
    ax.plot(t, norms_[:test_runs_to_plot].T, color='tab:orange')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [21]:
# Alternative start from the assembled sub-block solutions:
# var = tf.Variable(np.concatenate((Us_new[0], Vs_new[0].T), axis=-1))
var = tf.Variable(np.concatenate((Us[0], Vs[0].T), axis=-1))

opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)
losses = np.zeros(epochs)


def stager():
    """Loss callable for ``opt.minimize`` on the full system; records and reports progress."""
    U_raw, V_raw = tf.split(var, 2, axis=-1)
    U_unit = U_raw / tf.linalg.norm(U_raw, axis=-2, keepdims=True)
    V_t = tf.linalg.matrix_transpose(V_raw)
    loss = loss_f(U_unit, V_t, Js[0], ws[0], tau, beta, eigf=myeig)

    # First call: create the counter and timestamp attributes.
    if not hasattr(stager, 'count'):
        stager.count = 0
        stager.time = time()
    # Report four times as often as the sub-block runs.
    if stager.count % (print_every // 4) == 0:
        print(f'iter {stager.count:4d}, time {time() - stager.time:5.2f}, loss {loss:8.5f}, device {loss.device}')
        stager.time = time()

    losses[stager.count] = loss.numpy()
    stager.count += 1
    return loss


for _ in range(epochs):
    opt.minimize(stager, [var])

iter    0, time  0.00, loss  5.74241, device /job:localhost/replica:0/task:0/device:CPU:0
iter   50, time 29.23, loss  0.94449, device /job:localhost/replica:0/task:0/device:CPU:0
iter  100, time 28.87, loss  0.90660, device /job:localhost/replica:0/task:0/device:CPU:0
iter  150, time 28.94, loss  0.83384, device /job:localhost/replica:0/task:0/device:CPU:0
iter  200, time 28.62, loss  0.79063, device /job:localhost/replica:0/task:0/device:CPU:0
iter  250, time 28.71, loss  0.75856, device /job:localhost/replica:0/task:0/device:CPU:0
iter  300, time 28.92, loss  0.73272, device /job:localhost/replica:0/task:0/device:CPU:0
iter  350, time 28.24, loss  0.71088, device /job:localhost/replica:0/task:0/device:CPU:0
iter  400, time 27.92, loss  0.69184, device /job:localhost/replica:0/task:0/device:CPU:0
iter  450, time 27.23, loss  0.67487, device /job:localhost/replica:0/task:0/device:CPU:0
iter  500, time 27.17, loss  0.65953, device /job:localhost/replica:0/task:0/device:CPU:0
iter  550,

In [23]:
# Recover the trained factors with the same normalisation / transpose as in training.
U_end, V_end = (part.numpy() for part in tf.split(var, 2, axis=-1))
U_end = U_end / np.linalg.norm(U_end, axis=-2, keepdims=True)
V_end = V_end.T

# Eigendecompositions of the initial and the trained full system.
d0_, v0_ = np.linalg.eig(Js[0] + Us[0] @ Vs[0])
d_, v_ = np.linalg.eig(Js[0] + U_end @ V_end)
c0 = npr.randn(d_.shape[0], test_runs)
t = np.arange(n_seconds / dt) * dt

# Propagate the same initial vectors through both systems and keep the norms.
trajectory_norms = []
for eigvals, eigvecs in ((d0_, v0_), (d_, v_)):
    tmp = eigvecs @ (np.exp((eigvals[:, None] - 1) * t / tau) * np.linalg.solve(eigvecs, c0).T[..., None])
    assert np.allclose(tmp.imag, 0)
    trajectory_norms.append(np.linalg.norm(tmp, axis=-2))
norms0_, norms_ = trajectory_norms

fig, axes = plt.subplots(2, 2, figsize=[6.4 * 2, 4.8 * 2])
loss_ax, norm_ax = axes[0]

# Top-left: training loss.
loss_ax.plot(losses)
loss_ax.set_ylim(.3, 2)

# Top-right: first three runs, initial (blue) against trained (orange).
norm_ax.plot(t, norms0_[:3].T, color='tab:blue')
norm_ax.plot(t, norms_[:3].T, color='tab:orange')
norm_ax.set_ylim(-2, 22)

# Bottom row: spectrum of J alone and of the trained system (the initial
# perturbed spectrum is not drawn); the second panel has fixed limits.
theta = np.linspace(0, 2 * np.pi, 500)
d, v = np.linalg.eig(Js[0])
for i, ax in enumerate(axes[1]):
    ax.plot(np.cos(theta), np.sin(theta), 'k')
    for spectrum, marker, extra in ((d, '.', {}), (d_, 'o', {'fillstyle': 'none'})):
        ax.plot(spectrum.real, spectrum.imag, marker, **extra)
    ax.axis('equal')
    if i == 1:
        ax.set(xlim=(-1.1, 1.1), ylim=(-1.1, 1.1))



Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [None]:
# Compare the hand-written eig gradient (myeig) against the gradient
# TensorFlow registers for tf.linalg.eig, on a range of matrix sizes.
# NOTE(review): this cell rebinds N, Us, Vs, ws, tau and beta, which earlier
# cells also define; re-run the setup cells before re-running those.
N = 25
reps = 10
npr.seed(0)

# Matrix sizes N, 2N, ..., reps * N. Each list is drawn in full before the
# next so the random stream is consumed in a fixed order.
sizes = np.linspace(N, N * reps, reps, dtype=int)
Us = [tf.constant(npr.randn(n, (n * 2) // 5)) for n in sizes]
Vs = [tf.constant(npr.randn((n * 2) // 5, n)) for n in sizes]
As = [tf.constant(npr.randn(n, n) / np.sqrt(n)) for n in sizes]
ws = [tf.constant(npr.randn(n) / np.sqrt(n)) for n in sizes]
tau = 1
beta = 0.5


def _timed_forward(eigf):
    """Run loss_f with ``eigf`` on every problem; print the elapsed time.

    Returns the persistent tapes and the losses, one per problem.
    """
    start = time()
    tapes, values = [], []
    for i in range(reps):
        with tf.GradientTape(persistent=True) as tape:
            tape.watch([Us[i], Vs[i]])
            values.append(loss_f(Us[i], Vs[i], As[i], ws[i], tau, beta, eigf=eigf))
        tapes.append(tape)
    print(time() - start)
    return tapes, values


def _timed_backward(tapes, values):
    """Differentiate each loss w.r.t. its (U, V); print the elapsed time."""
    start = time()
    grads = [tape.gradient(loss, [U, V]) for tape, loss, U, V in zip(tapes, values, Us, Vs)]
    print(time() - start)
    return grads


# The reference pass must name tf.linalg.eig explicitly: loss_f defaults to
# eigf=myeig, so omitting it would compare myeig with itself. This needs a
# TensorFlow build that provides a gradient for tf.linalg.eig.
systapes, syslosses = _timed_forward(tf.linalg.eig)
mytapes, mylosses = _timed_forward(myeig)

sysgrads = _timed_backward(systapes, syslosses)
mygrads = _timed_backward(mytapes, mylosses)

for mygrad, sysgrad in zip(mygrads, sysgrads):
    print([(np.allclose(a, b), isclose(a, b, 1e-6)) for a, b in zip(mygrad, sysgrad)])