# SSP Tactile Encoding
This demo illustrates the memorization and replay of spatial tactile perception.

In [2]:
%matplotlib widget

from functools import partial

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import nengo
import nengo_spa as spa
import numpy as np
import seaborn as sns
from IPython.display import HTML, display
from nengo_extras.plot_spikes import plot_spikes, preprocess_spikes
from skimage.draw import disk, rectangle

from ssp.dynamics import Trajectory
from ssp.maps import Spatial2D
from ssp.plots import heatmap_animation

2023-10-12 17:00:22.439583: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-10-12 17:00:22.473547: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-10-12 17:00:22.474345: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-10-12 17:00:26.919478: E tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:266] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected


In [19]:
class StimulusGenerator2D:
    """Generates a 2D field with stimuli of a given profile."""

    def __init__(self, dim=(256,), sample_dim=(20,), profile=("ring", 150, 10)):
        if len(dim) < 2:
            dim = (dim[0], dim[0])
        if len(sample_dim) < 2:
            sample_dim = (sample_dim[0], sample_dim[0])
        self.dim = np.asarray(dim)
        self.sample_dim = np.asarray(sample_dim)
        self.__field = np.zeros((dim[0], dim[1], 3))

        match profile:
            case tuple():
                match profile[0]:
                    case "ring":
                        r0, c0 = disk((dim[1] // 2, dim[0] // 2), np.min(dim) // 3)
                        self.__field[r0, c0] = (1, 1, 1)
                        r1, c1 = disk(
                            (dim[1] // 2, dim[0] // 2), np.min(dim) // 3 - profile[1]
                        )
                        self.__field[r1, c1] = (0, 0, 0)
                    case "square":
                        start = self.dim[:2] // 2 - profile[1] // 2
                        end = start + profile[1]
                        r0, c0 = rectangle(start, end)
                        self.__field[r0, c0] = (1, 1, 1)
                        r1, c1 = rectangle(start + profile[2], end - profile[2])
                        self.__field[r1, c1] = (0, 0, 0)
            case _:
                pass

        self.annotation = {"thickness": 2}

    @property
    def field(self):
        return self.__field

    def sample(self, locations):
        """Samples the stimulus field at the given locations.

        Args:
            locations (array like): Array of locations to sample from.

        Returns:
            tuple: Tuple of sampled stimulus and annotated field.
        """
        locations = np.asarray(locations)
        end = locations + self.sample_dim

        sampled = []
        annotated = []
        thickness = self.annotation["thickness"]
        if not isinstance(locations, (list, np.ndarray)):
            locations = [locations]
        for i, loc in enumerate(locations):
            # Plots annotation rectangle
            ann = np.copy(self.field)
            r0, c0 = rectangle(
                np.flip(loc), end=np.flip(end[i]), shape=self.field.shape
            )
            ann[r0, c0] = (1, 0, 0)
            r1, c1 = rectangle(
                np.flip(loc) + thickness,
                end=np.flip(end[i]) - thickness,
                shape=self.field.shape,
            )
            ann[r1, c1] = self.field[r1, c1]
            sampled.append(np.copy(self.field[loc[1] : end[i][1], loc[0] : end[i][0]]))
            annotated.append(ann)
        return sampled, annotated


def create_animation(
    figures,
    titles=None,
    interval=100,
    cmap=sns.diverging_palette(220, 20, sep=20, as_cmap=True),
):
    """Auxiliary function to create a side-by-side animation of image sequences.

    Args:
        figures (array like): One image sequence per subplot; ``figures[i][step]``
            is the image shown in subplot ``i`` at frame ``step``. The number
            of frames is taken from the first sequence. Anything that is not a
            list or array is treated as a single sequence.
        titles (array like, optional): Subplot titles, paired with the
            subplots in order. Defaults to None (no titles).
        interval (int, optional): Delay between frames in milliseconds.
            Defaults to 100.
        cmap (palette, optional): Colormap passed to ``imshow``. Defaults to
            sns.diverging_palette(220, 20, sep=20, as_cmap=True).

    Returns:
        matplotlib.animation.ArtistAnimation: Animation with one frame per step.
    """
    if not isinstance(figures, list | np.ndarray):
        figures = np.asarray([figures])

    # squeeze=False always yields a 2D array of axes, so a single subplot
    # needs no special casing.
    fig, axes = plt.subplots(1, len(figures), figsize=(12, 6), squeeze=False)
    axes = axes[0]
    n_steps = len(figures[0])
    images = [
        [
            ax.imshow(figure[step], cmap=cmap, animated=True)
            for ax, figure in zip(axes, figures)
        ]
        for step in range(n_steps)
    ]
    if titles is not None:
        for ax, title in zip(axes, titles):
            ax.set_title(title)
    ani = animation.ArtistAnimation(fig, images, interval=interval, blit=True)
    # Close this specific figure (not whichever one happens to be current) so
    # the static figure is not displayed next to the animation.
    plt.close(fig)
    return ani

In [14]:
# Hollow-square stimulus: side `field_dim`, border width `sample_dim`.
# NOTE: `sample_dim` is passed as the third profile entry here; the generator's
# own sampling window is left at its constructor default.
field_dim = 150
sample_dim = 10
SG2D = StimulusGenerator2D(profile=("square", field_dim, sample_dim))

# Sampling locations: 50 integer (x, y) points on a circle of radius 80
# around (120, 120).
r = 80
t = np.linspace(0, 2 * np.pi, 50)
xs = 120 + r * np.cos(t)
ys = 120 + r * np.sin(t)
locations = list(zip(xs.astype(int), ys.astype(int)))

# Cut the window out at every location and animate window vs. annotated field.
sampled, annotated = SG2D.sample(locations)
ani = create_animation([sampled, annotated], titles=["ROI", "Field"])
display(HTML(ani.to_jshtml()))

In [16]:
# Adapted from ssp_grid_cell_utils.py and ssp_grid_cell_examples.ipynb
# Accurate representation for spatial cognition using grid cells
# Nicole Sandra-Yaffa Dumont & Chris Eliasmith


def ssp_plane_basis(K):
    """Creates the basis vectors X and Y from a matrix of wave vectors.

    Each row of ``K`` is a wave vector ``k_i = (u_i, v_i)``. To get hexagonal
    patterns use 3 wave vectors 120 degrees apart; to get multiple
    scales/orientations, stack many such sets of 3 wave vectors.

    For each axis the phases ``exp(1j * k)`` are laid out as a
    conjugate-symmetric spectrum (phases, a unit middle term, mirrored
    conjugates), shifted with ``ifftshift``, and the real part of its inverse
    FFT is used as the vector.

    Args:
        K (np.ndarray): Wave vectors, shape ``(d, 2)``.

    Returns:
        tuple: ``(X, Y)`` as ``spa.SemanticPointer`` objects of length
        ``2 * d + 1``.
    """

    def axis_pointer(wave_numbers):
        # Spectrum layout before shifting: [phases | 1 | flipped conjugates].
        n = wave_numbers.shape[0]
        phases = np.exp(1.0j * wave_numbers)
        spectrum = np.ones((n * 2 + 1,), dtype="complex")
        spectrum[0:n] = phases
        spectrum[-n:] = np.flip(np.conj(phases))
        shifted = np.fft.ifftshift(spectrum)
        return spa.SemanticPointer(data=np.fft.ifft(shifted).real)

    return axis_pointer(K[:, 0]), axis_pointer(K[:, 1])


def generate_grid_cell_basis(n_scales, n_rotates, scale_min=0.5, scale_max=1.8):
    """Generates basis vectors with ``d = n_scales * n_rotates * 6 + 1``.

    Starts from 3 unit wave vectors 120 degrees apart (which give a hexagonal
    grid interference pattern), replicates them at ``n_scales`` resolutions
    and ``n_rotates`` orientations, and turns the resulting wave vectors into
    the X and Y basis vectors via ``ssp_plane_basis``.

    Args:
        n_scales (int): Number of grid resolutions, spaced linearly between
            ``scale_min`` and ``scale_max`` (inclusive).
        n_rotates (int): Number of grid orientations, spaced evenly over
            [0, 60) degrees.
        scale_min (float, optional): Smallest scale factor. Defaults to 0.5.
        scale_max (float, optional): Largest scale factor. Defaults to 1.8.

    Returns:
        tuple: ``(X, Y, d)`` with the two basis vectors and their length.
    """
    K_hex = np.array([[0, 1], [np.sqrt(3) / 2, -0.5], [-np.sqrt(3) / 2, -0.5]])

    # One scaled copy of the 3 wave vectors per resolution: (3 * n_scales, 2).
    scales = np.linspace(scale_min, scale_max, n_scales)
    K_scales = np.vstack([K_hex * scale for scale in scales])

    # One 2x2 rotation matrix per orientation: (n_rotates, 2, 2).
    thetas = np.arange(0, n_rotates) * np.pi / (3 * n_rotates)
    cos_t, sin_t = np.cos(thetas), np.sin(thetas)
    R_mats = np.stack(
        [
            np.stack([cos_t, -sin_t], axis=1),
            np.stack([sin_t, cos_t], axis=1),
        ],
        axis=1,
    )

    # Rotate every scaled wave vector by every orientation. The product has
    # shape (n_rotates, 2, 3 * n_scales); swapping the last two axes puts one
    # wave vector per row before flattening to (n_rotates * 3 * n_scales, 2).
    K_scale_rotates = (R_mats @ K_scales.T).transpose(0, 2, 1).reshape(-1, 2)

    # Generate the (X, Y) basis vectors
    X, Y = ssp_plane_basis(K_scale_rotates)
    d = n_scales * n_rotates * 6 + 1
    assert len(X) == len(Y) == d
    return X, Y, d

In [24]:
# Length scale handed to the SSP map.
grid_size = 15  # in units of circle's diameter
ssp_radius = np.sqrt(2)  # open problem: deriving this
ssp_scale = ssp_radius * grid_size

# Grid-cell basis with 10 scales and 9 rotations, i.e.
# d = 10 * 9 * 6 + 1 = 541 dimensions.
X, Y, d = generate_grid_cell_basis(
    scale_min=0.9, scale_max=3.5, n_scales=10, n_rotates=9
)
dim = d

# Seeded RNG so the map is reproducible between runs.
ssp_map = Spatial2D(
    X=X, Y=Y, dim=dim, scale=ssp_scale, rng=np.random.RandomState(seed=0)
)
# The grid spans `field_dim` along each axis with `field_dim + 1` spaces.
n_grid_spaces = field_dim + 1
ssp_map.build_grid(
    x_len=field_dim,
    y_len=field_dim,
    x_spaces=n_grid_spaces,
    y_spaces=n_grid_spaces,
)

In [25]:
class Object:
    """Named point with ``x``/``y`` coordinates and an optional ``z``.

    Args:
        name: Identifier stored as ``name``.
        xyz: Indexable coordinates; the first two entries become ``x`` and
            ``y``. ``z`` is only set when a third entry exists.
    """

    def __init__(self, name, xyz):
        self.name = name
        self.x, self.y = xyz[0], xyz[1]
        # The z attribute is left undefined for two-coordinate input.
        if len(xyz) > 2:
            self.z = xyz[2]
        
# Start from a map updated with no objects, then add one object per sampling
# location and record the map's heatmap scores after each addition.
ssp_map.update_from_objs([])
mems = []
for i, loc in enumerate(locations):
    # Pixel coordinates are divided by 10 before being stored on the object
    # -- presumably to convert to the map's units; TODO confirm.
    loc = np.asarray(loc) / 10.0
    obj = Object(f"Obj_{i}", loc)
    ssp_map.add_object(obj)
    # NOTE(review): the attribute is stored without copying. If Spatial2D
    # updates heatmap_scores in place, every frame would alias the same
    # array -- confirm, or copy here.
    mems.append(ssp_map.heatmap_scores)

# One frame per added object.
ani = create_animation([mems], titles=["Memory"])
display(HTML(ani.to_jshtml()))

In [None]:
# Collapse the colour channels of every sampled window and rescale the whole
# stack to [0, 1].
stimuli = np.sum(sampled, axis=3)
stimuli = (stimuli - stimuli.min()) / (stimuli.max() - stimuli.min())
input_dim = stimuli.shape[1] * stimuli.shape[2]
n_samples = stimuli.shape[0]
duration = 0.2  # in seconds
time = n_samples * duration  # in seconds

with nengo.Network() as model:
    # Present each flattened window for `duration` seconds, cycling through
    # the samples in order.
    inp = nengo.Node(
        output=lambda t: stimuli[int(t / duration) % n_samples].flatten()
    )
    # The input feeds the neurons directly, so the ensemble gets one neuron
    # per input pixel instead of a hard-coded count.
    ens = nengo.Ensemble(n_neurons=input_dim, dimensions=36)
    nengo.Connection(inp, ens.neurons)
    p = nengo.Probe(ens, synapse=0.01)
    p_spikes = nengo.Probe(ens.neurons)

with nengo.Simulator(model, dt=1e-3) as sim:
    sim.run(time)

# Top: filtered ensemble output; bottom: spike raster.
fig, axs = plt.subplots(2, 1, figsize=(12, 7), sharex=True)
axs[0].plot(sim.trange(), sim.data[p])
plot_spikes(*preprocess_spikes(sim.trange(), sim.data[p_spikes]), ax=axs[1])
axs[1].set_ylabel("Neuron number")
axs[1].set_xlabel("Time [s]")
fig.tight_layout()
plt.show()