First, let's embed a ring in 3D.

In [1]:
import os
import sys
import matplotlib.pyplot as plt
import multiprocessing as mp
import numpy as np
from scipy import integrate
import torch

sys.path.append(os.path.join(os.getenv("HOME"), "RNN_Manifold/"))
import s1_direct_product_decoder, s1_direct_product_generator, geometry_util

In [2]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [3]:
# Train a generator for a 1-D manifold (the ring) with a 3-dimensional output.
encoder, decoder = s1_direct_product_generator.train(1, 3, device, n_training_iterations=3000)

# Sample the ring every 0.01 rad and push every sample through the trained encoder.
angles = np.arange(start=0, stop=2 * np.pi, step=0.01)
with torch.no_grad():
    points = geometry_util.torch_angles_to_ring(
        torch.tensor(angles, dtype=torch.get_default_dtype()).to(device)
    ).unsqueeze(-2)
    # Bring the embedding back to host memory as a numpy array for plotting.
    test_embedding = encoder(points).cpu().numpy()




iteration: 0, decoding loss: 1.061689019203186, distance cost: 0.01981627382338047
iteration: 3, decoding loss: 0.6815962195396423, distance cost: 0.009346794337034225
iteration: 4, decoding loss: 0.34227967262268066, distance cost: 0.03230316936969757
iteration: 5, decoding loss: 0.19943171739578247, distance cost: 0.06746604293584824
iteration: 13, decoding loss: 0.25694727897644043, distance cost: 0.007828565314412117
iteration: 14, decoding loss: 0.1944466382265091, distance cost: 0.007239856291562319
iteration: 15, decoding loss: 0.13218756020069122, distance cost: 0.009230935014784336
iteration: 16, decoding loss: 0.08851396292448044, distance cost: 0.011823172681033611
iteration: 17, decoding loss: 0.060444895178079605, distance cost: 0.014801518060266972
iteration: 18, decoding loss: 0.047143951058387756, distance cost: 0.01845703087747097
iteration: 165, decoding loss: 0.04258669540286064, distance cost: 0.01971469260752201
iteration: 166, decoding loss: 0.0382109135389328, di

The shape of the generated embedding should be essentially arbitrary: nothing in the loss function we use encourages any particular shape.

In [4]:
%matplotlib tk
proj_fig = plt.figure()
proj_axs = proj_fig.add_subplot(projection="3d")
proj_axs.scatter(test_embedding[:, 0], test_embedding[:, 1], test_embedding[:, 2], cmap="hsv", c=angles)


<mpl_toolkits.mplot3d.art3d.Path3DCollection at 0x7fb8fc2bc3a0>

Now let's generate a higher-dimensional ring, this time embedded in 24 dimensions.

In [5]:
# Train a generator for the ring again, this time with a 24-dimensional output.
encoder, decoder = s1_direct_product_generator.train(1, 24, device, n_training_iterations=3000)

# Sample the ring every 0.01 rad and embed every sample with the new encoder.
angles = np.arange(start=0, stop=2 * np.pi, step=0.01)
with torch.no_grad():
    points = geometry_util.torch_angles_to_ring(
        torch.tensor(angles, dtype=torch.get_default_dtype()).to(device)
    ).unsqueeze(-2)
    # Keep the generated data as a host-side numpy array for the decoding step.
    high_d_generated_ring_data = encoder(points).cpu().numpy()


iteration: 0, decoding loss: 0.9128285050392151, distance cost: 0.002092661103233695
iteration: 1, decoding loss: 0.6880248785018921, distance cost: 0.04595998302102089
iteration: 2, decoding loss: 0.4236714541912079, distance cost: 0.06351182609796524
iteration: 18, decoding loss: 0.3824080526828766, distance cost: 0.04834365472197533
iteration: 19, decoding loss: 0.33976036310195923, distance cost: 0.04819042608141899
iteration: 48, decoding loss: 0.3084816336631775, distance cost: 0.07775271683931351
iteration: 49, decoding loss: 0.267177939414978, distance cost: 0.07470875233411789
iteration: 50, decoding loss: 0.22744251787662506, distance cost: 0.06605216860771179
iteration: 51, decoding loss: 0.20574626326560974, distance cost: 0.0672905370593071
iteration: 52, decoding loss: 0.18668073415756226, distance cost: 0.07650379091501236
iteration: 53, decoding loss: 0.18195901811122894, distance cost: 0.06642839312553406
iteration: 63, decoding loss: 0.20828230679035187, distance cost

KeyboardInterrupt: 

Then decode it. First, rescale the data so that its mean absolute value is 1.

In [None]:
# Rescale so that the mean absolute activation is 1 before decoding.
high_d_generated_ring_data = high_d_generated_ring_data / np.abs(high_d_generated_ring_data).mean()

In [None]:
# Fit a decoder for a 1-D manifold (manifold_dim=1) to the normalised generated data.
encoder, decoder = s1_direct_product_decoder.train(
    data=high_d_generated_ring_data,
    manifold_dim=1,
    device=device,
    n_training_iterations=3000,
    decoder_weight=10,
    order_red_weight=0.1,
)


In [None]:
# Run the trained decoder over the data; only `decoded_angles` is used in this cell.
with torch.no_grad():
    decoded_points, decoded_angles = decoder(
        torch.tensor(high_d_generated_ring_data, dtype=torch.get_default_dtype()).to(device)
    )

# Drop singleton dimensions and move the recovered angles to numpy.
predicted_phases = decoded_angles.squeeze().cpu().numpy()

In [None]:
def reference_phases(phases):
    """Express phases relative to the first sample, in a canonical orientation.

    The phases are shifted so that ``phases[0]`` maps to 0 and wrapped into
    [-pi, pi]. If the second sample then has a negative phase, the whole
    sequence is mirrored so that sample 1 is non-negative. Applying this to
    both a decoded and a ground-truth phase sequence removes the global
    rotation and reflection that separate them.

    Args:
        phases: 1-D array-like of angles in radians.

    Returns:
        numpy array of referenced phases, the same length as ``phases``.
        Empty or single-element inputs are returned without an orientation
        flip instead of raising IndexError.
    """
    phases = np.asarray(phases)
    if phases.size == 0:
        return phases.copy()
    phases_refd = phases - phases[0]
    # Wrap into [-pi, pi] via the angle of the corresponding unit vector.
    phases_refd = np.arctan2(np.sin(phases_refd), np.cos(phases_refd))
    if phases_refd.size < 2:
        return phases_refd
    # np.sign(0) == 0 would zero out every phase when sample 1 coincides with
    # sample 0; treat that tie as the positive orientation instead.
    orientation = -1.0 if phases_refd[1] < 0 else 1.0
    return phases_refd * orientation

In [None]:
def compare_to_ground_truth(predicted_phases, ground_truth_phases, plot_ax):
    """Scatter decoded phases against ground-truth phases after aligning both.

    Both sequences are passed through ``reference_phases`` so that a global
    rotation or reflection between them does not hide agreement. A perfect
    decoder puts every point on the dashed y = x line.

    Args:
        predicted_phases: decoded angles in radians, one per datapoint.
        ground_truth_phases: true angles in radians, same order and length.
        plot_ax: matplotlib Axes to draw on.

    Returns:
        Tuple ``(refd_test_phases, refd_true_phases)`` holding the referenced
        predicted and ground-truth phases.
    """
    refd_test_phases = reference_phases(predicted_phases)
    refd_true_phases = reference_phases(ground_truth_phases)
    plot_ax.scatter(refd_true_phases, refd_test_phases)
    # Two endpoints are enough to draw the identity line over the phase range.
    identity = [-np.pi, np.pi]
    plot_ax.plot(identity, identity, color="black", linestyle="--", label="y=x")
    plot_ax.set_xlabel("True Phase")
    plot_ax.set_ylabel("Found Phase")
    # Without a legend the "y=x" label above is never displayed.
    plot_ax.legend()
    return refd_test_phases, refd_true_phases


In [None]:
%matplotlib inline
fig, ax = plt.subplots()
refd_predicted_phases, refd_true_phases = compare_to_ground_truth(predicted_phases, angles, ax)


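Now let's try a harder problem: decoding noisy data generated by a dynamical system. First, the dynamics. We will use a typical ring attractor model, $\dot{\mathbf{s}} = -\mathbf{s} + \sigma(\mathbf{w} \circledast \mathbf{s} + \mathbf{b})$, where $\circledast$ is circular convolution over the $N$ neurons, $w_k = -w_0 + w_1 \cos(2\pi k / N)$, and $\sigma$ is the logistic sigmoid.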

In [None]:
def conv_circ(signal, ker):
    """Circularly convolve every row of ``signal`` with ``ker`` via the FFT.

    Args:
        signal: 2-D array; each row is convolved independently along axis 1.
        ker: 1-D kernel whose length equals ``signal.shape[1]``.

    Returns:
        Complex array with the same shape as ``signal``.
    """
    signal_spectrum = np.fft.fft(signal, axis=1)
    kernel_spectrum = np.fft.fft(ker)
    # Pointwise product in frequency space == circular convolution in time.
    product = signal_spectrum * kernel_spectrum[np.newaxis, :]
    return np.fft.ifft(product, axis=1)


def cosine_kernel(w_0, w_1, N):
    """Build the ring connectivity kernel ``-w_0 + w_1 * cos(theta_k)``.

    ``theta_k = 2 * pi * k / N`` for k = 0, ..., N - 1, i.e. N evenly spaced
    angles around the ring.

    Args:
        w_0: constant subtracted from every weight.
        w_1: amplitude of the cosine term.
        N: number of neurons on the ring.

    Returns:
        1-D float array with exactly N entries.
    """
    step = 2 / N
    # Generate the grid from integer indices: np.arange with a float step has
    # no guaranteed length and can emit N + 1 points when (2 / step) rounds up,
    # which would break the length match with the signal in conv_circ.
    grid = np.arange(N) * step * np.pi
    return -w_0 + w_1 * np.cos(grid)


def ring_attractor_dynamics(state, kernel, bias_vec, nonlin_fn):
    """Right-hand side of the ring-attractor rate equation.

    Computes ``ds/dt = -s + nonlin_fn(conv_circ(s, kernel) + bias_vec)``, where
    the circular convolution runs over the neuron axis.

    Args:
        state: a single state of shape (N,), or a batch of shape (N, k) with
            one state per column (the layout solve_ivp passes when called with
            ``vectorized=True``).
        kernel: length-N convolution kernel.
        bias_vec: length-N input added to every state.
        nonlin_fn: elementwise nonlinearity applied to the total input.

    Returns:
        Time derivative with the same shape as ``state``. Previously a 1-D
        state produced an (N, 1) result.
    """
    is_single_state = np.ndim(state) == 1
    if is_single_state:
        state = np.expand_dims(state, -1)
    # conv_circ convolves along axis 1, so work with one state per row.
    rows = np.transpose(state)
    derivative = np.transpose(-rows + nonlin_fn(conv_circ(rows, kernel) + bias_vec))
    # Hand back the same rank the caller passed in.
    return derivative[:, 0] if is_single_state else derivative


def sigmoid(x):
    """Logistic sigmoid ``1 / (1 + exp(-x))``, evaluated without overflow.

    The naive ``exp(x) / (exp(x) + 1)`` becomes ``inf / inf = nan`` once
    ``exp(x)`` overflows (x above roughly 709 in float64). Here the exponential
    is only ever taken of an argument with non-positive real part, so its
    magnitude is at most 1. Accepts real or complex input.

    Args:
        x: scalar or array, real or complex.

    Returns:
        Value(s) of the sigmoid with the same shape as ``x``; a scalar input
        yields a numpy scalar.
    """
    x = np.asarray(x)
    nonneg = np.real(x) >= 0
    # |z| <= 1 in both branches: z = exp(-x) where Re x >= 0, else z = exp(x).
    z = np.exp(np.where(nonneg, -x, x))
    out = np.where(nonneg, 1 / (1 + z), z / (1 + z))
    # Unwrap 0-d arrays back to scalars; leave n-d arrays untouched.
    return out[()]

Generate some equilibrium states by integrating the dynamics from random initial conditions. We will generate 50 of them, to make things interesting.

In [None]:
N = 2 ** 7          # number of neurons on the ring
n_samples = 50      # number of independent equilibrium states to generate

# Random initial states. They are cast to complex because conv_circ returns
# complex values; solve_ivp integrates in the complex domain only when y0 has
# a complex dtype. (np.complex was removed in NumPy 1.24; use np.complex128.)
init_conds = np.random.uniform(-1, 1, (n_samples, N)).astype(np.complex128)

w_0 = 1
w_1 = 1
kernel = cosine_kernel(w_0, w_1, N)


def run_equilibriation(init_conds, t_final=20):
    """Integrate the ring-attractor dynamics from one initial state.

    Uses the module-level ``kernel`` and ``N``, zero bias and the ``sigmoid``
    nonlinearity. The state reached at ``t_final`` is treated as the
    equilibrium.

    Args:
        init_conds: initial state of shape (N,), complex dtype.
        t_final: integration horizon (default 20).

    Returns:
        State vector of shape (N,) at the end of the integration.

    Raises:
        RuntimeError: if solve_ivp reports failure. Previously this case
            silently returned whatever partial state had been reached.
    """
    # The bias never changes, so build it once rather than on every RHS call.
    bias = np.zeros(N)
    soln = integrate.solve_ivp(lambda _, y: ring_attractor_dynamics(y, kernel, bias, sigmoid),
                               [0, t_final], init_conds, vectorized=True)
    if not soln.success:
        raise RuntimeError(f"solve_ivp failed: {soln.message}")
    return soln.y[:, -1]


# Equilibrate every initial condition in parallel, one solve_ivp run per row.
# NOTE(review): run_equilibriation is defined in the notebook itself, so this
# presumably relies on the 'fork' start method (Linux default); under 'spawn'
# (macOS/Windows default) workers may fail to find it. TODO confirm.
pool = mp.Pool()
try:
    solns = np.array(pool.map(run_equilibriation, init_conds))
finally:
    # Release the worker processes even if an integration raises.
    pool.close()
    pool.join()

Normalise the equilibrium states and add some Gaussian noise.

In [None]:
# Scale so that the mean absolute activation is 1, keeping only the real part
# (the states were integrated with a complex dtype).
ring_attractor_data = np.real(solns / np.abs(solns).mean())
# Additive Gaussian noise whose standard deviation is 1/15 of the peak activation.
noise_std = np.max(ring_attractor_data) / 15
noisy_ring_attractor_data = ring_attractor_data + np.random.normal(0, noise_std, ring_attractor_data.shape)

Record the ground-truth phases: the phase of each state is the angle of its most active neuron.

In [None]:
# Ground-truth phase of each clean state: the angle of its most active neuron among the N around the ring.
true_ring_att_phases = 2 * np.pi * (np.argmax(ring_attractor_data, axis=1) / N)

In [None]:
# Heatmap of the noisy data: one row per equilibrium sample, one column per neuron.
fig, ax = plt.subplots()
image = ax.imshow(noisy_ring_attractor_data, aspect="auto")
# Without a colour bar the activation scale of the heatmap is unreadable.
fig.colorbar(image, ax=ax, label="Activation")
ax.set_title("Raw activation data")
ax.set_xlabel("Neuron Index")
# Trailing semicolon keeps the Text repr out of the cell output.
ax.set_ylabel("Datapoint");


Now decode

In [None]:
# Fit a decoder for a 1-D manifold (manifold_dim=1) to the noisy ring-attractor states,
# with the same hyperparameters as for the generated ring.
encoder, decoder = s1_direct_product_decoder.train(
    data=noisy_ring_attractor_data,
    manifold_dim=1,
    device=device,
    n_training_iterations=3000,
    decoder_weight=10,
    order_red_weight=0.1,
)


In [None]:
# Decode the noisy states; only `decoded_angles` is used in this cell.
with torch.no_grad():
    decoded_points, decoded_angles = decoder(
        torch.tensor(noisy_ring_attractor_data, dtype=torch.get_default_dtype()).to(device)
    )

# Drop singleton dimensions and move the recovered angles to numpy.
ring_att_predicted_phases = decoded_angles.squeeze().cpu().numpy()

In [None]:
%matplotlib inline
fig, ax = plt.subplots()
refd_ring_att_predicted_phases, refd_ring_att_true_phases = compare_to_ground_truth(ring_att_predicted_phases, true_ring_att_phases, ax)
