In [1]:
# Start from a clean namespace (-f skips the confirmation prompt).
%reset -f
# Automatically reload edited modules before each cell is executed.
%load_ext autoreload
%autoreload 2
# Render figures with the interactive ipympl (widget) matplotlib backend.
%matplotlib ipympl

In [2]:
import numpy as np
import numpy.random as npr
np.set_printoptions(linewidth=200)
from matplotlib import pyplot as plt
import scipy as sp
from scipy import stats

from time import time

from myutils import *

In [3]:
import tensorflow as tf
import tensorflow_probability as tfp
from ctrnn import CTRNNCell, PlotEachBatch

# tf.keras.backend.set_floatx('float32')

ImportError: cannot import name 'naming' from 'tensorflow.python.autograph.core' (/Users/sean/miniconda3/envs/rocky/lib/python3.8/site-packages/tensorflow/python/autograph/core/__init__.py)

In [4]:
# Show the TensorFlow build the kernel is running.
tf.__version__

'2.4.0-dev20200722'

In [5]:
def _real_with_test(x, name='x', tol=1e-6):
    """Return the real part of ``x``, printing a warning if the imaginary part is not small.

    The size of the imaginary part is measured as ``||imag(x)|| / ||x||`` over the
    flattened tensor.

    Args:
        x: Tensor to take the real part of.
        name: Label used for ``x`` in the warning message.
        tol: Relative imaginary norm above which the warning is printed.

    Returns:
        ``tf.math.real(x)``; the check never alters the returned value.
    """
    imag_norm = tf.linalg.norm(tf.reshape(tf.math.imag(x), -1))
    # Cast the total norm to the real dtype of ``imag_norm`` instead of a
    # hard-coded float64, so the division also works for complex64/float32
    # input (TensorFlow does not promote float32 / float64 automatically).
    tot_norm = tf.cast(tf.linalg.norm(tf.reshape(x, -1)), imag_norm.dtype)
    rel_imag = imag_norm / tot_norm
    if rel_imag > tol:
        tf.print(f'||{name}.imag||/||{name}|| = {rel_imag}')
    return tf.math.real(x)

In [6]:
# Notebook-level hyperparameters. Only some are consumed by the cells visible in
# this notebook; the others are presumably meant for the CTRNN model — TODO confirm.
N = 25  # size of the random N x N matrix A built below (rebound by later cells)
act = tf.tanh  # not referenced in the visible cells; presumably the unit nonlinearity
g = 1.5  # not referenced in the visible cells (loss_f uses its own local g)
dt = 0.001  # step size: T = round(dur / dt) steps, and dt_tau = dt / tau
tau = 0.01  # time constant used in dt_tau (rebound to 1 by the benchmark cell)
sigma = 0.001  # not referenced in the visible cells
alpha = 1.0  # not referenced in the visible cells
classes = 2  # number of target functions / one-hot input classes

dur = 1.0  # total duration; together with dt it fixes the number of steps T
pos1 = 0.35  # centre of the first Gaussian bump, as a fraction of T
pos2 = 0.65  # centre of the second Gaussian bump, as a fraction of T
width = 0.1  # Gaussian bump width (standard deviation), as a fraction of T

In [7]:
# Step size expressed in units of the time constant.
dt_tau = dt / tau

# Number of steps per trial and the matching step indices 0 .. T-1.
T = round(dur / dt)
t = np.arange(T)

# Target for class 0: two Gaussian bumps centred at pos1*T and pos2*T,
# both with standard deviation width*T (all measured in steps).
bump_1 = np.exp(-0.5 * (t - pos1 * T)**2 / (width * T)**2)
bump_2 = np.exp(-0.5 * (t - pos2 * T)**2 / (width * T)**2)

# Target for class 1: triangle wave in [-1, 1] with period 1 in units of t*dt.
triangle = 1 - np.abs((t * dt * 4 - 1) % 4 - 2)

# One row per class.
f = np.stack([bump_1 + bump_2, triangle])
assert len(f) == classes, f'{classes} target functions needed, but only {len(f)} given'

In [8]:
# Class label of each of the 500 trials, cycling 0, 1, ..., classes-1.
labels = np.arange(500) % classes

# Every trial's input is its class label repeated at all T steps, one-hot encoded.
inputs = get_one_hot(
    np.stack([np.full(T, label, dtype=int) for label in labels]),
    classes,
)

# Every trial's target is the function of its class, with a trailing singleton axis.
targets = f[labels][..., None]

In [9]:
# Plot the target function of every class against time (step index * dt).
# Axis labels and a legend are added so the figure can be read on its own.
fig, a = plt.subplots()
for k, y in enumerate(f):
    a.plot(t * dt, y, label=f'class {k}')
a.set(xlabel='time (t * dt)', ylabel='target', title='Target function per class')
a.legend();

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [10]:
# Draw a reproducible N x N Gaussian matrix scaled by 1/sqrt(N) and diagonalise it.
npr.seed(2)
A = npr.randn(N, N) / np.sqrt(N)
d, v = np.linalg.eig(A)

# Plot the eigenvalues in the complex plane together with the unit circle.
theta = np.linspace(0, 2 * np.pi, 500)
ax = plt.figure().gca()
ax.plot(np.cos(theta), np.sin(theta), 'k')
ax.plot(d.real, d.imag, '.')
ax.axis('equal')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

(-1.0999791907454364,
 1.099999009083116,
 -1.3312458901902486,
 1.3312458901902486)

In [11]:
def loss_f(U, V, A, w, tau, beta, eigf=tf.linalg.eig):
    """Scalar loss computed from the eigendecomposition of ``J = A + U @ V``.

    With ``d, r = eigf(J)`` the function builds

    * ``L  = inv(r^H r)``
    * ``g  = (d - 1) / tau``
    * ``G1[i, j] = -1 / (g[i] + conj(g[j]))``
    * ``G2[i, j] = g[i] * conj(g[j]) * G1[i, j]``
    * ``wR = r^T w``

    and returns the real part of
    ``sum((r @ (L * G1)) * conj(r)) / N + beta * sum(wR * ((L * G2) @ conj(wR)))``,
    where ``*`` is element-wise and ``N = A.shape[0]``.

    Args:
        U, V: Factors whose product ``U @ V`` is added to ``A``.
        A: Square matrix; its first dimension is used as ``N``.
        w: Vector multiplied by ``r^T``; cast to the dtype of ``r``.
        tau: Scalar dividing ``d - 1``.
        beta: Scalar weight of the second term.
        eigf: Function returning ``(eigenvalues, eigenvectors)`` for a matrix.

    Returns:
        Real scalar tensor.
    """
    N = A.shape[0]
    J = A + U @ V
    d, r = eigf(J)
    L = tf.linalg.inv(tf.linalg.adjoint(r) @ r)
    g = (d - 1) / tau
    G1 = -1 / (g[:, None] + tf.math.conj(g))
    G2 = (g[:, None] * tf.math.conj(g)) * G1
    # Cast to the eigenvector dtype rather than a hard-coded complex128 so the
    # function also works when eigf returns complex64 (float32 inputs).
    wR = tf.linalg.matvec(r, tf.cast(w, r.dtype), transpose_a=True)
    first_term = tf.reduce_sum((r @ (L * G1)) * tf.math.conj(r)) / N
    second_term = beta * tf.reduce_sum(wR * tf.linalg.matvec(L * G2, tf.math.conj(wR)))
    return tf.math.real(first_term + second_term)

def _reciprocal(x, ep=1e-20):
    return x / (x * x + ep)

@tf.custom_gradient
def myeig(A):
    """``tf.linalg.eig`` with a hand-written gradient.

    The forward pass returns ``(e, v)`` exactly as ``tf.linalg.eig(A)`` does; the
    backward pass is supplied by the nested ``grad``. The benchmark cell further
    down compares the resulting gradients with those of plain ``tf.linalg.eig``.

    NOTE(review): a later cell (In [98]) defines a NumPy function with the same
    name ``myeig``, which shadows this one once that cell has run.
    """
    e, v = tf.linalg.eig(A)
    def grad(grad_e, grad_v):
        """Map upstream gradients w.r.t. ``(e, v)`` to a gradient w.r.t. ``A``."""
        # f[..., i, j] = regularised 1 / (e[j] - e[i]); the diagonal is zeroed
        # and the whole matrix is conjugated.
        f = _reciprocal(e[..., None, :] - e[..., None])
        f = tf.linalg.set_diag(f, tf.zeros_like(e))
        f = tf.math.conj(f)
        vt = tf.linalg.adjoint(v)
        vgv = vt @ grad_v
        # Real part of diag(v^H grad_v), cast back to the complex dtype.
        diag_grad_part = tf.cast(tf.math.real(tf.linalg.diag_part(vgv)), vgv.dtype)
        # v * diag_grad_part[..., None, :] scales column j of v by diag_grad_part[j],
        # i.e. it is v @ diag(diag_grad_part).
        mid = tf.linalg.diag(grad_e) + f * (vgv - vt @ (v * diag_grad_part[..., None, :]))
        # Solves vt @ grad_a = mid @ vt, i.e. grad_a = inv(v^H) @ mid @ v^H.
        grad_a = tf.linalg.solve(vt, mid @ vt)
        # The gradient is cast to the dtype of the input A.
        return tf.cast(grad_a, A.dtype)
    return (e, v), grad

In [12]:
N = 25
reps = 10
npr.seed(0)

# Problem sizes N, 2N, ..., reps*N.
sizes = np.linspace(N, N * reps, reps, dtype=int)

# Random test problems, one per size. The draw order (all Us, then all Vs, As, ws)
# determines the random stream and must not be changed.
Us = [tf.constant(npr.randn(n, (n * 2) // 5)) for n in sizes]
Vs = [tf.constant(npr.randn((n * 2) // 5, n)) for n in sizes]
As = [tf.constant(npr.randn(n, n) / np.sqrt(n)) for n in sizes]
ws = [tf.constant(npr.randn(n) / np.sqrt(n)) for n in sizes]
# NOTE: this rebinds the notebook-level tau that was set in the hyperparameter cell.
tau = 1
beta = 0.5


def _record_losses(eigf):
    """Evaluate loss_f on every test problem, each under its own persistent tape."""
    tapes, losses = [], []
    for U, V, A_n, w in zip(Us, Vs, As, ws):
        with tf.GradientTape(persistent=True) as tape:
            # U and V are constants, so they have to be watched explicitly.
            tape.watch([U, V])
            losses.append(loss_f(U, V, A_n, w, tau, beta, eigf=eigf))
        tapes.append(tape)
    return tapes, losses


def _backprop(tapes, losses):
    """Gradients of each recorded loss with respect to its own [U, V]."""
    return [tape.gradient(loss, [U, V]) for tape, loss, U, V in zip(tapes, losses, Us, Vs)]


# Forward pass with the built-in eigendecomposition (timed).
curtime = time()
systapes, syslosses = _record_losses(tf.linalg.eig)
print(time() - curtime)

# Forward pass with the custom-gradient eigendecomposition (timed).
curtime = time()
mytapes, mylosses = _record_losses(myeig)
print(time() - curtime)

# Backward pass through the built-in gradient (timed).
curtime = time()
sysgrads = _backprop(systapes, syslosses)
print(time() - curtime)

# Backward pass through the custom gradient (timed).
curtime = time()
mygrads = _backprop(mytapes, mylosses)
print(time() - curtime)

# Compare the two sets of gradients, one printed line per problem size.
for mygrad, sysgrad in zip(mygrads, sysgrads):
    print([(np.allclose(a, b), isclose(a, b, 1e-6)) for a, b in zip(mygrad, sysgrad)])

2.0241079330444336
2.0073421001434326
0.16692900657653809
0.16349577903747559
[(True, (True, 2.718997071556545e-15)), (True, (True, 1.256914119116992e-15))]
[(True, (True, 2.250345406733898e-15)), (True, (True, 1.3108167355767136e-15))]
[(True, (True, 7.521378025613445e-15)), (True, (True, 4.970108736850761e-15))]
[(True, (True, 4.6834545820503055e-15)), (True, (True, 2.5860783855761247e-15))]
[(True, (True, 4.0001986555501494e-14)), (True, (True, 2.2451181989856864e-14))]
[(True, (True, 3.672746717114926e-13)), (True, (True, 1.2874293979415579e-13))]
[(True, (True, 1.9123238042331307e-14)), (True, (True, 9.255436734099315e-15))]
[(True, (True, 3.248692434161891e-14)), (True, (True, 1.5959244276483033e-14))]
[(True, (True, 1.4527505343602214e-13)), (True, (True, 6.578688873667015e-14))]
[(True, (True, 1.474282845272314e-14)), (True, (True, 9.268780660497847e-15))]


In [None]:
# Batched eigendecomposition of random 10 x 10 matrices with batch shape (2, 3).
A = tf.random.normal((2, 3, 10, 10), dtype=tf.float64)
d, r = tf.eig(A)

# Order the eigenvalues of every matrix by the key Re(d) + Im(d).
b_dims = len(d.shape) - 1
idx = tf.argsort(tf.math.real(d) + tf.math.imag(d), axis=-1)
d_ = tf.gather(d, idx, batch_dims=b_dims)
# Eigenvectors are the COLUMNS of r, so they must be reordered along the last
# axis. With the default axis (== batch_dims) tf.gather would permute the rows
# instead, leaving r_ inconsistent with d_.
r_ = tf.gather(r, idx, axis=b_dims + 1, batch_dims=b_dims)

# Sanity check: the batched gather equals an explicit column-by-column reordering.
r_[0, 0], tf.stack([r[0, 0, :, i] for i in idx[0, 0]], axis=-1)

In [None]:
a = tf.constant(np.arange(6).reshape(2, 3))
# Two ways of permuting the columns of `a` into the order [0, 2, 1]:
# explicit (row, column) pairs with gather_nd, versus transpose -> gather rows -> transpose.
# The public tf.* functions are used so this cell does not rely on the private
# `array_ops` module, which is only imported in a later cell.
tf.gather_nd(a, [[[0, 0], [0, 2], [0, 1]], [[1, 0], [1, 2], [1, 1]]]), tf.linalg.matrix_transpose(tf.gather(tf.linalg.matrix_transpose(a), [0, 2, 1]))

In [None]:
# Display the (row, column) index pairs used with tf.gather_nd in the previous cell
# as an array, to inspect their layout.
np.array([[[0, 0], [0, 2], [0, 1]], [[1, 0], [1, 2], [1, 1]]])

In [None]:
# Check that the first matrix of the batch is recovered as r_ @ diag(d_) @ inv(r_).
# This can only hold if the columns of r_ are in the same order as the entries of d_.
np.allclose(tf.cast(A[0, 0], tf.complex128), r_[0, 0] @ tf.linalg.diag(d_[0, 0]) @ tf.linalg.inv(r_[0, 0]))

In [None]:
from tensorflow.python.ops import array_ops

In [None]:
# Build a random 3 x 3 complex tensor and display it next to the result of
# adding the real scalar 1.5 to it.
a = tf.complex(npr.randn(3, 3), npr.randn(3, 3))
a, a + 1.5

In [53]:
# Two non-orthogonal eigenvectors as the columns of r, each scaled to unit length.
r = np.array(([1., 1.], [1., 2.]))
r = r / np.linalg.norm(r, axis=0)
# Prescribed eigenvalues.
d = np.array([2., -3.])
# a = r @ diag(d) @ inv(r), obtained from a linear solve against r.T instead of
# forming the inverse explicitly.
a = np.linalg.solve(r.T, (r * d).T).T
# a = a @ a.T
# Recompute the eigenpairs numerically (this rebinds d and r) and set l = inv(r).
d, r = np.linalg.eig(a)
l = np.linalg.inv(r)
# Display a together with the reconstruction residual a - r @ diag(d) @ l.
a, a - (r * d) @ l

(array([[ 7., -5.],
        [10., -8.]]),
 array([[-8.8817842e-16,  0.0000000e+00],
        [ 0.0000000e+00,  0.0000000e+00]]))

In [54]:
# Finite-difference step; later cells reuse this `ep`.
ep = 1e-6
# Left block: l. Right block: forward-difference change of the eigenvector matrix
# when ep is added to both off-diagonal entries of a.
# NOTE(review): np.linalg.eig is called independently on both matrices; the
# difference is only meaningful if both calls return the eigenvectors in the same
# order and with the same sign/scale — verify.
np.hstack((l, (np.linalg.eig(a + [[0, ep], [ep, 0]])[1] - np.linalg.eig(a)[1]) / ep))

array([[ 2.82842712e+00, -1.41421356e+00,  1.11022302e-10, -1.07331262e-01],
       [-2.23606798e+00,  2.23606798e+00,  0.00000000e+00,  5.36656228e-02]])

In [55]:
# Multiply r.T onto the same forward-difference change of the eigenvector matrix
# (ep added to both off-diagonal entries of a).
r.T @ ((np.linalg.eig(a + [[0, ep], [ep, 0]])[1] - np.linalg.eig(a)[1]) / ep)

array([[ 7.85046229e-11, -3.79473373e-02],
       [ 4.96506831e-11, -7.24899973e-09]])

In [98]:
def myeig(a):
    """Eigendecomposition of ``a`` with sorted eigenvalues and a fixed eigenvector phase.

    The eigenvalues are returned in ``np.argsort`` order and the eigenvector
    columns are permuted to match. Each column is then multiplied by a unit
    complex number chosen so that its first component has zero phase.

    Note: this reuses the name of the TensorFlow ``myeig`` defined earlier in the
    notebook and replaces it once this cell has run.

    Returns:
        ``(d, r)``: sorted eigenvalues and the complex eigenvector matrix
        (eigenvectors as columns).
    """
    eigvals, eigvecs = np.linalg.eig(a)
    order = np.argsort(eigvals)
    # Unit phase factor per column that rotates the first-row entry onto the
    # non-negative real axis.
    phase_fix = np.exp(-np.angle(eigvecs[0, order]) * 1j)
    return eigvals[order], eigvecs[:, order] * phase_fix

In [107]:
N = 5
i, j = 2, 3
npr.seed(0)
a = npr.randn(N, N)
d, r = myeig(a)
l = np.linalg.inv(r)
# Check the decomposition: a == r @ diag(d) @ l.
print(np.allclose(a, (r * d) @ l))
# Perturb the single entry a[i, j] by ep (ep is defined in an earlier cell).
# np.eye(1, N, i).T * np.eye(1, N, j) is the N x N matrix with a one at (i, j);
# N replaces the previously hard-coded 5, and the perturbed eigenvectors are
# computed once instead of twice.
r_pert = myeig(a + ep * np.eye(1, N, i).T * np.eye(1, N, j))[1]
# With np.allclose's default tolerances the perturbed eigenvectors still match r.
print(np.allclose(r, r_pert))
# Forward-difference change of the eigenvectors, multiplied by l (trailing 0 kept
# so the displayed tuple is unchanged).
l @ ((r_pert - r) / ep), 0

True
True


(array([[-5.04977876e-01-2.87022106e-01j,  5.85491942e-01-3.14662586e-01j,  2.84134578e-01+3.29456960e-01j,  1.17878199e-02+2.22586186e-02j,  1.13282231e-02+1.79859714e-02j],
        [ 5.85491942e-01+3.14662586e-01j, -5.04977876e-01+2.87022106e-01j,  2.84134578e-01-3.29456960e-01j,  1.13282231e-02-1.79859714e-02j,  1.17878199e-02-2.22586186e-02j],
        [ 1.97208344e-01-4.37940471e-02j,  1.97208344e-01+4.37940471e-02j, -4.78530404e-01+1.63127580e-17j, -1.54219079e-02-7.60103118e-03j, -1.54219079e-02+7.60103118e-03j],
        [-1.33534149e-02+4.73701169e-02j, -1.37565812e-02+3.91905175e-02j, -4.80654734e-02+5.40799402e-02j, -5.59473320e-04+5.56933999e-04j, -4.44972926e-03-1.22506191e-03j],
        [-1.37565812e-02-3.91905175e-02j, -1.33534149e-02-4.73701169e-02j, -4.80654734e-02-5.40799402e-02j, -4.44972926e-03+1.22506191e-03j, -5.59473320e-04-5.56933999e-04j]]),
 0)

In [109]:
# f[i, j] = 1 / (d[j] - d[i]) off the diagonal; adding inf on the diagonal makes
# the diagonal entries 0 instead of dividing by zero.
# np.inf replaces np.infty (removed in NumPy 2.0) and len(d) replaces the
# hard-coded 5. NOTE: this rebinds `f`, which held the target functions earlier.
f = 1 / (d - d[:, None] + np.diag(np.inf * np.ones(len(d))))


array([[ 0.        +0.j        ,  0.        -1.30060987j,  0.68571721-0.19545018j,  0.27758211+0.21533138j,  0.19769616-0.2209446j ],
       [-0.        +1.30060987j,  0.        +0.j        ,  0.68571721+0.19545018j,  0.19769616+0.2209446j ,  0.27758211-0.21533138j],
       [-0.68571721+0.19545018j, -0.68571721-0.19545018j,  0.        +0.j        ,  0.16848114+0.39842665j,  0.16848114-0.39842665j],
       [-0.27758211-0.21533138j, -0.19769616-0.2209446j , -0.16848114-0.39842665j,  0.        +0.j        ,  0.        -0.23483581j],
       [-0.19769616+0.2209446j , -0.27758211+0.21533138j, -0.16848114+0.39842665j, -0.        +0.23483581j,  0.        +0.j        ]])

In [103]:
# Display l, which earlier cells compute as np.linalg.inv(r).
l

array([[ 0.21316024+1.65607966e-01j,  0.76863769-1.36931603e+00j, -0.41321134-8.91555897e-01j, -0.36639891+1.12330187e+00j, -0.2466083 +7.64473545e-01j],
       [ 0.21316024-1.65607966e-01j,  0.76863769+1.36931603e+00j, -0.41321134+8.91555897e-01j, -0.36639891-1.12330187e+00j, -0.2466083 -7.64473545e-01j],
       [-0.46075964+2.39335188e-16j, -1.51491026+2.27531001e-16j,  0.54475547-9.24524371e-17j, -1.03250401+7.00104548e-17j,  0.40835738-6.13329824e-17j],
       [ 0.78322958-1.20706658e-01j, -0.05916687+6.79858420e-01j,  0.11574507+2.43233497e-01j,  0.57440695+2.85078463e-01j,  0.04316516+4.68863293e-01j],
       [ 0.78322958+1.20706658e-01j, -0.05916687-6.79858420e-01j,  0.11574507-2.43233497e-01j,  0.57440695-2.85078463e-01j,  0.04316516-4.68863293e-01j]])

In [104]:
# l is computed as inv(r), so this product should be the identity up to round-off.
l @ r

array([[ 1.00000000e+00-2.22044605e-16j,  0.00000000e+00+5.55111512e-17j, -1.11022302e-16+0.00000000e+00j, -5.55111512e-17+0.00000000e+00j, -2.22044605e-16-1.17961196e-16j],
       [ 1.66533454e-16-2.22044605e-16j,  1.00000000e+00-1.11022302e-16j, -5.55111512e-17-1.11022302e-16j,  0.00000000e+00+0.00000000e+00j,  2.22044605e-16-5.55111512e-17j],
       [-5.55111512e-17+6.93889390e-17j,  2.77555756e-17+1.66533454e-16j,  1.00000000e+00+9.14942329e-17j, -2.77555756e-17+0.00000000e+00j, -5.55111512e-17+2.22044605e-16j],
       [ 4.16333634e-17-2.77555756e-17j,  1.38777878e-17-1.38777878e-16j, -1.04083409e-17+0.00000000e+00j,  1.00000000e+00+0.00000000e+00j,  0.00000000e+00-2.77555756e-17j],
       [ 2.77555756e-17+2.77555756e-17j, -5.55111512e-17-8.32667268e-17j, -6.93889390e-18-5.55111512e-17j,  5.55111512e-17+0.00000000e+00j,  1.00000000e+00-1.38777878e-17j]])