### Exploring the recovery of connectivity from sparse observational data

**What does it take to reverse engineer a worm: Power and friends**

Running perturbation experiments to reverse engineer the input-output functions of all neurons in C. elegans has recently been proposed as a strategy towards whole-nervous-system emulation. However, this leads to difficult inverse problems in which causal interactions have to be inferred from observational data. Here we ask how hard it is to reverse engineer such systems, using both theoretical and experimental approaches. We find that reverse engineering benefits considerably from stimulation, requires the ability to integrate experimental data across animals, and would likely require thousands of experiments. Nonetheless, the resulting numbers should allow the neuronal input-output functions to be reverse engineered under relatively moderate assumptions about the simplicity of neurons in C. elegans.

The connectome of C. elegans is known to be roughly the same across genotypically matched animals, yet how that connectivity ultimately determines neural dynamics and behavior across different animals is unknown. We treat reverse engineering the nervous system of C. elegans as a system identification problem, in which multiple partially observed instances of the dynamics of a common system are used to reconstruct the true underlying system. We believe this framing is useful for reasoning about causal structure in the brain.

*Last updated: 6 June 2024*

#### Import libraries & modules
---


In [1]:
import torch
import pickle
import random
import numpy as np
import scipy as scp
import seaborn as sns
import networkx as nx
import matplotlib.pyplot as plt

from typing import Union
from omegaconf import OmegaConf
from sklearn.preprocessing import StandardScaler
from preprocess._utils import smooth_data_preprocess, reshape_calcium_data
from utils import NUM_NEURONS, NEURON_LABELS, init_random_seeds

# Initialize the random seeds
init_random_seeds(42)

Loading from /net/vast-storage/scratch/vast/yanglab/qsimeon/worm-graph/data/raw/neuron_master_sheet.csv.

CUDA device found.
	 GPU: NVIDIA A100 80GB PCIe


##### Helper functions
___

In [2]:
def plot_neural_signals(data, time_tensor, neuron_idx=None, yax_limit=True, suptitle=None):
    assert isinstance(data, torch.Tensor), "data must be a PyTorch tensor"
    assert isinstance(time_tensor, torch.Tensor), "time_tensor must be a PyTorch tensor"
    assert data.dim() == 2, "data must be a 2D tensor"
    assert neuron_idx is None or isinstance(
        neuron_idx, (int, list)
    ), "neuron_idx must be None, an integer, or a list"

    time_tensor = time_tensor.squeeze()
    assert data.size(0) == time_tensor.size(0), "Number of rows in data and time_tensor must match"

    num_neurons = data.size(1)

    # Choose which columns (neurons) to plot:
    # - None: plot all neurons
    # - int: plot that many randomly chosen neurons
    # - list: plot exactly the given neuron indices
    if neuron_idx is None:
        column_indices = np.arange(num_neurons)
    elif isinstance(neuron_idx, int):
        assert neuron_idx <= num_neurons, "neuron_idx cannot exceed the number of neurons"
        column_indices = np.random.choice(num_neurons, neuron_idx, replace=False)
    else:
        assert len(neuron_idx) <= num_neurons, "neuron_idx cannot exceed the number of neurons"
        column_indices = np.array(neuron_idx)

    num_columns = len(column_indices)

    # Extract the selected columns from the data tensor
    selected_columns = data[:, column_indices]

    # Use a distinct color for each plotted neuron
    colors = sns.color_palette("bright", num_columns)

    # Plot one vertically stacked subplot per neuron
    fig, axs = plt.subplots(num_columns, 1, figsize=(12, num_columns))

    # plt.subplots returns a single Axes (not an array) when num_columns == 1,
    # so wrap it in a list to make it iterable
    if num_columns == 1:
        axs = [axs]

    for i, ax in enumerate(axs):
        ax.plot(time_tensor, selected_columns[:, i], color=colors[i])
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.spines["bottom"].set_visible(False)
        ax.spines["left"].set_visible(False)
        ax.yaxis.set_ticks_position("none")
        if yax_limit:
            ax.set_ylim(-1.0, 1.0)
        ax.set_ylabel("{}".format(NEURON_LABELS[column_indices[i]]))

        if i < num_columns - 1:
            ax.set_xticks([])
        else:
            ax.set_xlabel("Time (s)")

    # Add a super title to the figure if provided
    if suptitle is not None:
        fig.suptitle(suptitle, fontsize=16)

    plt.tight_layout(pad=1)

    plt.show()

In [3]:
def plot_3d_trajectory(X, axis_labels=("Time", "Value", "Z Axis"), title="Trajectory", show=True):
    """
    Plot a trajectory from a dataset, which can be 1D, 2D, or 3D.

    Parameters:
    - X: A 2D numpy array containing the trajectory data. Must be shaped as (time, features).
    - axis_labels: A tuple containing the labels for the axes. Default is ('Time', 'Value', 'Z Axis').
    - title: Title of the plot.
    - show: If True, the plot will be displayed. If False, the plot object will be returned.

    Returns:
    - fig, ax: The figure and axis objects of the plot if show is False.
    """
    max_timesteps = X.shape[0]
    dims = X.shape[1] if len(X.shape) > 1 else 1

    # Create a new figure for the plot
    fig = plt.figure(figsize=(12, 12))
    if dims >= 3:
        ax = fig.add_subplot(111, projection="3d")
        plot_dims = 3
    else:
        ax = fig.add_subplot(111)
        plot_dims = dims

    # Create a color map based on the time progression
    norm = plt.Normalize(0, max_timesteps)
    colors = plt.cm.viridis(norm(np.arange(max_timesteps)))

    # Extract coordinates from the data and plot accordingly
    if plot_dims == 1:
        x = np.arange(max_timesteps)
        y = X.flatten()
        z = np.zeros(max_timesteps)
        ax_labels = (axis_labels[0], "Value", "Fixed Z")
        # Plot the 1D trajectory with a color gradient
        for i in range(1, max_timesteps):
            ax.plot(x[i - 1 : i + 1], y[i - 1 : i + 1], color=colors[i], lw=0.5)
        # Mark the start and end of the trajectory
        ax.plot(x[0], y[0], "g*", markersize=8)  # Start with a green star
        ax.plot(x[-1], y[-1], "ro", markersize=7)  # End with a red circle
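    elif plot_dims == 2:
        # Pick two random columns to plot
        ind_x, ind_y = np.random.choice(X.shape[1], 2, replace=False)
        x, y = X[:, ind_x], X[:, ind_y]
        z = np.zeros(max_timesteps)
        # Label each axis with its column's name when one label is given per column
        if len(axis_labels) == X.shape[1]:
            ax_labels = (axis_labels[ind_x], axis_labels[ind_y], "Fixed Z")
        else:
            ax_labels = (axis_labels[0], axis_labels[1], "Fixed Z")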
        # Plot the 2D trajectory
        for i in range(1, max_timesteps):
            ax.plot(x[i - 1 : i + 1], y[i - 1 : i + 1], color=colors[i], lw=0.5)
        # Mark the start and end of the trajectory
        ax.plot(x[0], y[0], "g*", markersize=8)  # Start with a green star
        ax.plot(x[-1], y[-1], "ro", markersize=7)  # End with a red circle
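    else:  # plot_dims == 3
        # Pick three random columns to plot
        ind_x, ind_y, ind_z = np.random.choice(X.shape[1], 3, replace=False)
        x, y, z = X[:, ind_x], X[:, ind_y], X[:, ind_z]
        # Label each axis with its column's name when one label is given per column
        if len(axis_labels) == X.shape[1]:
            ax_labels = (axis_labels[ind_x], axis_labels[ind_y], axis_labels[ind_z])
        else:
            ax_labels = axis_labels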
        # Plot the 3D trajectory
        for i in range(1, max_timesteps):
            ax.plot(
                x[i - 1 : i + 1],
                y[i - 1 : i + 1],
                z[i - 1 : i + 1],
                color=colors[i],
                lw=0.5,
            )
        # Mark the start and end of the trajectory
        ax.plot(x[0], y[0], z[0], "g*", markersize=8)  # Start with a green star
        ax.plot(x[-1], y[-1], z[-1], "ro", markersize=7)  # End with a red circle

    # Set labels for the axes
    ax.set_xlabel(ax_labels[0])
    ax.set_ylabel(ax_labels[1])
    if plot_dims == 3:
        ax.set_zlabel(ax_labels[2])

    # Set title
    ax.set_title(title)

    # Show the plot
    if show:
        plt.show()
    else:
        return fig, ax

In [4]:
def save_synthetic_dataset(file_name, dataset):
    with open(file_name, "wb") as f:
        pickle.dump(dataset, f)

In [5]:
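def calculate_spectral_radius(matrix):
    # The spectral radius is the largest modulus |lambda| over all (possibly complex)
    # eigenvalues; taking the real part first would underestimate it
    eigenvalues = np.linalg.eigvals(matrix)
    spectral_radius = np.max(np.abs(eigenvalues))
    return spectral_radius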

In [6]:
def adjust_matrix_to_edge_of_chaos(matrix, target_radius=1.0):
    # Calculate the current spectral radius
    current_radius = calculate_spectral_radius(matrix)

    # Calculate the gain needed to adjust the spectral radius to the target
    gain = target_radius / current_radius

    # Scale the matrix by the gain
    adjusted_matrix = matrix * gain

    return adjusted_matrix, gain
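
A quick check of the two helpers above: a spectral radius has to use the complex modulus of the eigenvalues, because a matrix with purely imaginary eigenvalues has zero real parts but radius 1. The sketch below re-implements the helpers under the hypothetical names `spectral_radius` and `rescale_to_radius` so it runs on its own, and uses a 90° rotation matrix as the test case.

```python
import numpy as np


def spectral_radius(matrix):
    # Largest complex modulus |lambda| over all eigenvalues
    return float(np.max(np.abs(np.linalg.eigvals(matrix))))


def rescale_to_radius(matrix, target_radius=1.0):
    # Multiplying a matrix by g multiplies every eigenvalue by g,
    # so rho(g * W) = g * rho(W)
    gain = target_radius / spectral_radius(matrix)
    return matrix * gain, gain


# A 90-degree rotation has eigenvalues +i and -i: real parts 0, moduli 1
R = np.array([[0.0, -1.0], [1.0, 0.0]])
rho_modulus = spectral_radius(R)
rho_real_part = float(np.max(np.abs(np.linalg.eigvals(R).real)))
print(f"modulus-based: {rho_modulus:.3f}, real-part-based: {rho_real_part:.3f}")

# Rescale a random non-negative matrix to the "edge of chaos" (spectral radius 1)
rng = np.random.default_rng(0)
W = rng.uniform(size=(5, 5))
W_scaled, gain = rescale_to_radius(W, target_radius=1.0)
print(f"gain = {gain:.3f}, new radius = {spectral_radius(W_scaled):.3f}")
```
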

#### Get default parameter values from configs and utils

In [7]:
config = OmegaConf.load("../configs/submodule/preprocess.yaml")
DELTA_T = config.preprocess.resample_dt

#### Recurrent Network dataset

Dynamics evolve according to

$$ \tau \frac{d\mathbf{x}}{dt} = -\mathbf{x} + \mathbf{M} f(\mathbf{x}) + \mathbf{b} $$

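where $\mathbf{b}$ is a vector of external inputs and $\mathbf{M}$ is a connectivity matrix, the weighted adjacency matrix of a sparse random graph. The non-zero weights could in general be chosen i.i.d. $\sim \mathcal{N}(0,1)$; for simplicity, the simulation below instead draws them roughly uniformly from $(0, 1]$ and then rescales $\mathbf{M}$ to spectral radius 1.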

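We define $\alpha = \frac{\Delta t}{\tau}$. Applying the forward Euler method, $\mathbf{x}(t + \Delta t) \approx \mathbf{x}(t) + \frac{\Delta t}{\tau}\left(-\mathbf{x}(t) + \mathbf{M} f(\mathbf{x}(t)) + \mathbf{b}(t)\right)$, gives the discrete-time update equation of the RNN: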

$$ \mathbf{x}(t + \Delta t) = \left(1 - \alpha\right) \mathbf{x}(t) + \alpha \mathbf{M} f\left(\mathbf{x}(t)\right) + \alpha \mathbf{b}(t) $$
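
To see what $\alpha$ does, consider the simplest case: a single neuron with no recurrence ($\mathbf{M} = \mathbf{0}$) and a constant input $b$. The update is then affine and relaxes geometrically toward $b$ at rate $\alpha$, $x_n = b + (x_0 - b)(1-\alpha)^n$. The sketch below uses a hypothetical `rnn_step` helper and arbitrary values $\alpha = 0.1$, $b = 2$ to check the discrete update against that closed form.

```python
import numpy as np


def rnn_step(x, M, b, alpha, f=np.tanh):
    # One forward-Euler step of tau dx/dt = -x + M f(x) + b, with alpha = dt / tau
    return (1.0 - alpha) * x + alpha * (M @ f(x)) + alpha * b


alpha = 0.1              # assumed ratio dt / tau
M = np.zeros((1, 1))     # no recurrence
b = np.array([2.0])      # constant input
x0 = np.array([0.0])

x = x0.copy()
n_steps = 30
for _ in range(n_steps):
    x = rnn_step(x, M, b, alpha)

# With M = 0 the update is affine, so x_n = b + (x_0 - b) * (1 - alpha) ** n
closed_form = b + (x0 - b) * (1.0 - alpha) ** n_steps
print(x, closed_form)
```

A smaller $\alpha$ (a longer time constant relative to the sampling interval) makes the neuron approach $b$ more slowly.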

**Considerations:**
- The time constant $\tau$ may differ between neurons in the network, in which case it should be viewed as a vector $\boldsymbol{\tau}$. Furthermore, $\tau$ could be time-dependent, $\tau(t)$, for example due to neuromodulation.
- For our simulations, and also with real experimental data, we generally know (or can resample to) the sampling interval (i.e. measurement timestep) $\Delta t$. But since we do not know (or lack good estimates of) the time constant $\tau$, $\alpha$ is unknown.
- However, a common assumption in rate-based modeling is that $\Delta t$ is much shorter than $\tau$.
- Therefore $\alpha$ (or $\boldsymbol{\alpha}$ in the vectorized case) is positive and bounded, $0 < \alpha < 1$.
- Setting $\alpha = 1$ makes the model simpler to work with and, as argued below, may not reduce the expressivity of the network.

---

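We will simplify the problem by assuming $\alpha=1$. This may not reduce expressivity, because the effect of the time constant can be subsumed into the eigenvalues of the connectivity matrix. In the linear regime ($f(\mathbf{x}) \approx \mathbf{x}$), the general update above is $\mathbf{x}(t + \Delta t) = \left[(1-\alpha)\mathbf{I} + \alpha\mathbf{M}\right]\mathbf{x}(t) + \alpha\mathbf{b}(t)$, whose effective matrix has eigenvalues $1 - \alpha + \alpha\lambda_{\mathbf{M}}$, so an $\alpha = 1$ model with that effective matrix and input $\alpha\mathbf{b}$ behaves identically. With $\alpha = 1$ the update becomes: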

$$ \mathbf{x}(t + \Delta t) =  \mathbf{M} f\left(\mathbf{x}(t)\right) + \mathbf{b}(t) $$
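
The claim that the time constant is absorbed into the connectivity can be checked numerically in the linear regime. The sketch below, with illustrative sizes and $\alpha = 0.2$, runs the leaky update with $f$ taken as the identity next to an $\alpha = 1$ model whose matrix is $(1-\alpha)\mathbf{I} + \alpha\mathbf{M}$ and whose input is $\alpha\mathbf{b}$. The equivalence is exact only for linear $f$; with $f = \tanh$ the leak term sits outside the nonlinearity and the two models differ.

```python
import numpy as np

rng = np.random.default_rng(1)
n, T, alpha = 4, 50, 0.2          # illustrative sizes; alpha = dt / tau
M = rng.normal(scale=0.3, size=(n, n))
b = rng.normal(size=(T, n))       # arbitrary input sequence

# alpha = 1 model with the time constant folded into its parameters
M_eff = (1.0 - alpha) * np.eye(n) + alpha * M
b_eff = alpha * b

x_leaky = rng.normal(size=n)
x_unit = x_leaky.copy()
for t in range(T):
    # Leaky update with f = identity (the linear regime)
    x_leaky = (1.0 - alpha) * x_leaky + alpha * (M @ x_leaky) + alpha * b[t]
    # alpha = 1 update with the effective parameters
    x_unit = M_eff @ x_unit + b_eff[t]

# Eigenvalues of M_eff should be 1 - alpha + alpha * (eigenvalues of M)
mapped = 1.0 - alpha + alpha * np.linalg.eigvals(M)
eig_ok = all(np.min(np.abs(mapped - lam)) < 1e-8 for lam in np.linalg.eigvals(M_eff))
print(np.allclose(x_leaky, x_unit), eig_ok)
```
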

In [10]:
def create_synthetic_dataset_recurrent(
    max_timesteps: int = 1000,
    num_worms: int = 1,
    num_signals: int = 10, #NUM_NEURONS,
    num_named_neurons: Union[None, int] = None,
    add_noise: bool = False,
    noise_std: float = 0.01,
    random_walk: bool = False,
    # # >>> any special arguments for this function should go here >>>
    input_signal_gain: float = 1.0,
    intrinsic_noise_gain: float = 0.0,
    num_sensors: int = 1,
    # signal_type: Union[None, str] = "step", # options: {"step", "sine"}
    # # <<< any special arguments for this function should go here <<<
    delta_seconds: float = DELTA_T,
    smooth_method: Union[None, str] = "exponential",
    transform: Union[None, callable] = StandardScaler(),
    dataset_name: str = "Recurrent0000",
):
    """
    Create a synthetic worm datasets using the Lorenz attractor.
    Three neurons are chosen randomly to represent x, y, z trajectories from the Lorenz system.

    :param max_timesteps: The number of timepoints of synthetic data to generate.
    :param num_worms: The number of synthetic worms to create datasets for.
    :param num_signals: The number of signals corresponding to number of neurons.
    :param num_named_neurons: The number of measured neurons to create non-zero signals for.
    :param add_noise: Whether to simulate measurement noise by adding Gaussian noise to the synthetic data.
    :param noise_std: The standard deviation of the i.i.d Gaussian measurement noise.
    :param random_walk: If True, use a random walk to generate the noise. Otherwise, use iid noise.
    :param delta_seconds: The constant time difference (in seconds) between each measurement of the system.
    :param smooth_method: The method to use for smoothing the data.
    :param transform: The sklearn method to scale or transform the data before use.
    :param dataset_name: The name to give the synthetic dataset.
    :return: A dictionary containing the synthetic worm datasets.
    """
    # ### DEBUG ###
    # # For when doing more general (not C. elegans) graphs
    # from string import printable # up to 100 characters
    # NEURON_LABELS = printable[:num_signals]
    # print(NEURON_LABELS)
    # ### DEBUG ###
    eps = np.finfo(float).eps
    dataset = {}
    # Determine the timepoints for sampling the data (i.e. "measurement" times)
    time_in_seconds = delta_seconds * np.arange(max_timesteps).reshape(-1, 1)  # column vector
    # Calculate number of named and unknown neurons
    if num_named_neurons is None or num_named_neurons > num_signals:  # default to all neurons
        num_named_neurons = num_signals
    elif num_named_neurons < 0:  # default to no neurons
        num_named_neurons = 0
    num_unknown_neurons = num_signals - num_named_neurons
    # Define a fixed sparse connectivity matrix, shared by all worms, by increasing the
    # edge probability until the graph is connected (nx.Graph ignores edge direction,
    # so this checks weak connectivity)
    sparsity = 0.0
    connected = False
    while not connected:
        sparsity = min(1.0, sparsity + 1 / num_signals)  # cap at 1 to keep p valid
        sparse_mask = np.random.choice(
            [0, 1], size=(num_signals, num_signals), p=[1 - sparsity, sparsity]
        )
        # Enforcing autapses (diagonal connections) means that, with continuous random
        # weights, the connectivity matrix is full rank with probability 1
        sparse_mask[np.diag_indices(num_signals)] = 1
        G = nx.from_numpy_array(sparse_mask)
        connected = nx.is_connected(G)
    # Non-negative weights sampled from [eps, 1 + eps), i.e. roughly Unif(0, 1], for simplicity
    connectivity_matrix = sparse_mask * (eps + np.random.rand(num_signals, num_signals))
    # Adjust the connectivity matrix to the edge of chaos
    connectivity_matrix, _ = adjust_matrix_to_edge_of_chaos(connectivity_matrix, target_radius=1.0)
    spectral_radius = calculate_spectral_radius(connectivity_matrix)  # should be close to 1.0
    assert np.isclose(spectral_radius, 1.0, atol=1e-2), f"Spectral radius is not close to 1.0: {spectral_radius}"
    ### DEBUG ###
    # Evolve the dynamics specified by the fixed connectivity (alpha = 1, so no explicit time constants)
    ### DEBUG ###
    # Some warmup timesteps to allow the system to reach a steady state
    warmup_timesteps = max_timesteps // 6
    simulation_steps = max_timesteps + warmup_timesteps
    ### DEBUG ###
    # Calculate the signal-to-noise ratio
    signal_noise_ratio = input_signal_gain / max(
        eps, intrinsic_noise_gain
    )  # avoid division by zero
    # Choose which neurons to be the sensory interface (i.e receive inputs)
    sensorium = sorted(np.random.choice(num_signals, size=num_sensors, replace=False))
    # Signal (i.e. control law) will only be applied at the sensorium
    sensory_mask = np.zeros(num_signals)
    sensory_mask[sensorium] = 1
    # Add a unique phase shift to each neuron
    phase_shift = np.random.uniform(low=0, high=2 * np.pi, size=num_signals)
    ### DEBUG ###
    # Create data for each worm
    # TODO: This can be parallelized since each worm is independent
    for worm_idx in range(num_worms):
        # Initialize worm data
        worm = f"worm{worm_idx}"
        worm_data = dict()
        calcium_data = np.zeros((max_timesteps + warmup_timesteps, num_signals))
        # Choose a random subset of neurons to record / observe / measure
        named_neuron_indices = random.sample(
            range(num_signals), num_named_neurons
        )  # without replacement
        named_neurons = set(f"node_{idx}" for idx in named_neuron_indices)
        # Create neuron to idx mapping and vice versa (recorded neurons are labeled
        # "node_{idx}", unrecorded ones by their index as a string)
        neuron_to_idx = {
            (f"node_{idx}") if f"node_{idx}" in named_neurons else str(idx): idx
            for idx in range(num_signals)
        }
        idx_to_neuron = {idx: neuron for neuron, idx in neuron_to_idx.items()}
        # We define "input" as the signal (on or off) applied on top of background intrinsic noise
        input_matrix = np.zeros_like(calcium_data)  # filled in during the simulation below
        # Create calcium data by evolving the dynamics
        for t in range(simulation_steps):
            # Initial conditions
            if t == 0:
                signal = input_signal_gain * np.zeros(num_signals)  # zero input signal at the start
                noise = intrinsic_noise_gain * np.random.randn(num_signals)  # random noise
                # Apply input signal only at the sensorium but intrinsic noise everywhere
                inputs = sensory_mask * signal + noise
                # Set the initial state of the network
                # state = np.zeros(num_signals)  # zeros initialization
                state = np.random.uniform(
                    low=-5.0, high=5.0, size=num_signals
                )  # random initialization
            # Evolve recurrent dynamics
            else:
                # Specify the input signal to apply at each neuron for this timestep
                # TODO: In reality, we want to design the control input u = signal to achieve some desired trajectory.
                # ### DEBUG ###
                # # Sinusoidal input signal # TODO: make an argument of the function
                # freq = 10 # number of cycles in the simulation # TODO: make an argument of the function
                # signal = input_signal_gain * np.sin(2 * np.pi * freq * t / (simulation_steps - 1) + phase_shift/freq) # unique phase shift per neuron
                # ### DEBUG ###
                ### DEBUG ###
                # Step input signal # TODO: make an argument of the function
                modulo = 100  # period (in timesteps) of one cycle # TODO: make an argument of the function
                signal = input_signal_gain * np.heaviside(
                    ((t + modulo * phase_shift) % modulo) - modulo / 2, 0
                )  # unique phase shift per neuron
                ### DEBUG ###
                noise = intrinsic_noise_gain * np.random.randn(num_signals)
                # Apply input signal only at the sensorium but intrinsic noise everywhere
                inputs = sensory_mask * signal + noise
                # Integrate the inputs and current state to get the next state
                state = connectivity_matrix @ np.tanh(state) + inputs
            # Record or 'measure' the state of the named neurons and all the inputs
            calcium_data[t][named_neuron_indices] = state[named_neuron_indices]
            input_matrix[t] = inputs
        # Discard warmup timesteps
        calcium_data = calcium_data[warmup_timesteps:]
        input_matrix = input_matrix[warmup_timesteps:]
        # Add i.i.d measurement/observation noise
        if add_noise:
            for neuron_index in named_neuron_indices:
                if random_walk:
                    noise_walk = np.cumsum(
                        [0]
                        + np.random.normal(loc=0, scale=noise_std, size=max_timesteps - 1).tolist()
                    )
                    calcium_data[:, neuron_index] += noise_walk
                else:
                    noise_iid = np.random.normal(0, noise_std, max_timesteps)
                    calcium_data[:, neuron_index] += noise_iid
        # Normalize the data
        if transform:
            calcium_data = transform.fit_transform(calcium_data)
        # Calculate residuals (time derivative of the calcium signal)
        dt = np.diff(time_in_seconds, axis=0, prepend=0.0)
        resample_dt = np.median(dt[1:]).item()
        residual_calcium = np.gradient(calcium_data, time_in_seconds.squeeze(), axis=0)
        # Smooth the data and convert to tensors
        smooth_calcium_data = smooth_data_preprocess(
            calcium_data,
            time_in_seconds,
            smooth_method,
            **dict(alpha=0.5, window_size=15, sigma=5),
        )
        smooth_residual_calcium = smooth_data_preprocess(
            residual_calcium,
            time_in_seconds,
            smooth_method,
            **dict(alpha=0.5, window_size=15, sigma=5),
        )
        # Save the data
        worm_data["worm"] = worm
        worm_data["source_dataset"] = dataset_name
        worm_data["smooth_method"] = smooth_method
        worm_data["calcium_data"] = calcium_data
        worm_data["smooth_calcium_data"] = smooth_calcium_data
        worm_data["residual_calcium"] = residual_calcium
        worm_data["smooth_residual_calcium"] = smooth_residual_calcium
        worm_data["max_timesteps"] = max_timesteps
        worm_data["time_in_seconds"] = time_in_seconds
        worm_data["dt"] = dt
        worm_data["median_dt"] = resample_dt
        worm_data["neuron_to_idx"] = neuron_to_idx
        worm_data["idx_to_neuron"] = idx_to_neuron
        worm_data["num_neurons"] = num_signals
        worm_data["num_named_neurons"] = num_named_neurons
        worm_data["num_unknown_neurons"] = num_unknown_neurons
        worm_data["extra_info"] = {
            "adjacency_matrix": sparse_mask,
            "connection_weights": connectivity_matrix,
            "input_matrix": input_matrix,
            "spectral_radius": spectral_radius,
            "sparsity": sparsity,
            ### DEBUG ###
            "sensorium": sensorium,
            "sensory_mask": sensory_mask,
            "signal_noise_ratio": signal_noise_ratio,
            ### DEBUG ###
            "measurement_noise": add_noise,
            "iid_noise_std": noise_std * int(add_noise),
            "meta_text": "`input_matrix` is a matrix where each row is the pattern of input applied to the network.\n"
            "adjacency_matrix` is binary.\n`connection_weights` has the edge strengths.\n"
            "`spectral_radius` is the maximum absolute value of the eigenvalues of `connection_weights`.\n"
            "`measurement_noise` indicates whether i.i.d Gaussian noise with mean 0 and variance `iid_noise_std`^2 was added to the observed 'neural' data.\n",
        }
        # # Reshape the data to the standardized format
        # worm_data = reshape_calcium_data(worm_data)
        # Save the data
        dataset[worm] = worm_data
    return dataset
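
The connectivity construction inside `create_synthetic_dataset_recurrent` can also be exercised on its own. The sketch below pulls it into a hypothetical `make_connectivity` helper that takes an explicit NumPy `Generator`. It mirrors the loop above (growing the edge probability until the undirected view of the graph is connected, enforcing autapses, drawing positive uniform weights, and rescaling to spectral radius 1), but swaps `nx.is_connected` for a small hypothetical graph-search helper, `is_weakly_connected`, so that it needs only NumPy.

```python
import numpy as np


def is_weakly_connected(mask):
    # Graph search on the undirected view (mask OR mask.T), as nx.Graph would see it
    adj = (mask + mask.T) > 0
    visited = np.zeros(len(adj), dtype=bool)
    visited[0] = True
    frontier = [0]
    while frontier:
        node = frontier.pop()
        for nbr in np.flatnonzero(adj[node] & ~visited):
            visited[nbr] = True
            frontier.append(nbr)
    return bool(visited.all())


def make_connectivity(num_signals, rng):
    sparsity = 0.0
    while True:
        # Raise the edge probability in steps of 1 / num_signals (capped at 1)
        sparsity = min(1.0, sparsity + 1.0 / num_signals)
        mask = (rng.random((num_signals, num_signals)) < sparsity).astype(int)
        np.fill_diagonal(mask, 1)  # enforce autapses
        if is_weakly_connected(mask):
            break
    # Positive weights on the existing edges, then rescale to spectral radius 1
    weights = mask * rng.uniform(np.finfo(float).eps, 1.0, size=mask.shape)
    weights /= np.max(np.abs(np.linalg.eigvals(weights)))
    return mask, weights, sparsity


mask, W, sparsity = make_connectivity(10, np.random.default_rng(0))
print(f"sparsity = {sparsity:.1f}, edges = {mask.sum()}")
print(f"rank = {np.linalg.matrix_rank(W)}, radius = {np.max(np.abs(np.linalg.eigvals(W))):.3f}")
```
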

##### Matrix $A$ Definition and Eigenvalues

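Let $A$ be a square matrix defined as:
$$A = -I + D$$
where $I$ is the identity matrix and $D$ is a general square matrix (not necessarily diagonal) with spectral radius $\rho(D) \leq 1$. For example, $A$ is the Jacobian of the continuous-time dynamics above in the linear regime ($f(\mathbf{x}) \approx \mathbf{x}$, $\mathbf{b} = \mathbf{0}$, time measured in units of $\tau$), with $D = \mathbf{M}$.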

---

##### Spectral Radius and Eigenvalues

The spectral radius of a general matrix $M$ is defined as:
$$\rho(M) = \max_{\lambda \in \sigma(M)} |\lambda|$$
where $\sigma(M)$ denotes the spectrum (set of eigenvalues) of $M$.

Given $\rho(D) \leq 1$, the absolute value of any eigenvalue of $D$ does not exceed 1. This bounds the eigenvalues $\lambda_D$ of $D$ such that:
$$|\lambda_D| \leq 1$$

---

##### Impact on Eigenvalues of $A$

The eigenvalues of $A$, denoted $\lambda_A$, are exactly the eigenvalues of $D$ shifted by $-1$; this holds for any square $D$, not only in the diagonal case. If $\lambda_D$ is an eigenvalue of $D$, then $\lambda_A = -1 + \lambda_D$ is an eigenvalue of $A$, because
$$\det(A - \lambda I) = \det((-I + D) - \lambda I) = \det(D - (1 + \lambda) I) = 0$$
holds exactly when $1 + \lambda$ is an eigenvalue of $D$.

The eigenvalues $\lambda_A$ of $A$ are therefore:
$$\lambda_A = -1 + \lambda_D$$

---

##### Stability Analysis

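To check the stability, note that for every eigenvalue $\lambda_D$ of $D$:
$$\operatorname{Re}(\lambda_A) = \operatorname{Re}(\lambda_D) - 1 \leq |\lambda_D| - 1 \leq \rho(D) - 1 \leq 0$$
- The largest possible real part of $\lambda_A$ is zero, and it is reached only when $\operatorname{Re}(\lambda_D) = |\lambda_D| = 1$, i.e. when $\lambda_D = 1$ exactly.
- Any other eigenvalue on the unit circle (e.g. $\lambda_D = -1$, or a complex $\lambda_D = e^{i\theta}$ with $\theta \neq 0$) still gives $\operatorname{Re}(\lambda_A) < 0$.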

---

##### Conclusion

With the real parts of the eigenvalues of $A$ bounded above by zero, we arrive at a similar conclusion as in the diagonal case:
- The system characterized by $\frac{dx}{dt} = Ax(t)$ is at best **marginally stable** if $\lambda_D = 1$ is an eigenvalue of $D$: it is marginally stable when that eigenvalue is semisimple (its algebraic and geometric multiplicities agree), and unstable otherwise, since a nontrivial Jordan block produces solutions that grow polynomially in time.
- If every eigenvalue of $D$ has real part strictly less than 1, which is guaranteed when all have magnitude strictly less than 1, then all eigenvalues of $A$ have negative real parts, leading to an (asymptotically) **stable** system.

Thus, the system's stability hinges on whether $1$ is an eigenvalue of $D$. If $\rho(D) < 1$ (strictly less), the system is stable. If $\rho(D) = 1$, the system is marginally stable only when $\lambda_D = 1$ itself is an eigenvalue (and is semisimple); unit-modulus eigenvalues other than $1$ still yield a stable system.
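
These conclusions are easy to check numerically. The sketch below (arbitrary seed and size) confirms that the eigenvalues of $A = -I + D$ are those of a non-diagonal $D$ shifted by $-1$, that their real parts are bounded by $\rho(D) - 1$, and that $\rho(D) = 1$ alone does not imply marginal stability when the unit-modulus eigenvalue is $-1$.

```python
import numpy as np

rng = np.random.default_rng(2)
n = 6

# A general (non-diagonal) D, rescaled so that rho(D) = 0.9 < 1
D = rng.normal(size=(n, n))
D *= 0.9 / np.max(np.abs(np.linalg.eigvals(D)))
A = -np.eye(n) + D

eig_D = np.linalg.eigvals(D)
eig_A = np.linalg.eigvals(A)
# Each eigenvalue of A should match some eigenvalue of D shifted by -1
shift_ok = all(np.min(np.abs((eig_D - 1.0) - lam)) < 1e-8 for lam in eig_A)
max_real_A = np.max(eig_A.real)      # bounded by rho(D) - 1 = -0.1
print(shift_ok, round(max_real_A, 3))

# rho(D) = 1 with the unit-modulus eigenvalue at -1: A = -I + D is still stable
D_edge = np.diag([-1.0, 0.5])
eig_edge = np.linalg.eigvals(-np.eye(2) + D_edge)   # eigenvalues -2 and -0.5
print(eig_edge)
```
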

---

In [11]:
# Initialize parameters
max_timesteps = 1000
num_worms = 100
num_signals = 10 #NUM_NEURONS
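# Reflect a cost-benefit tradeoff between throughput and accuracy
# (recording more worms means recording fewer neurons per worm).
# N.B. this is negative whenever num_worms > num_signals; the dataset function then
# clamps it to 0, so no neurons are recorded and all calcium traces are zero
num_named_neurons = num_signals - num_worms + 1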
add_noise = False  # measurement noise
noise_std = 0.01
random_walk = False
input_signal_gain = 1.0
intrinsic_noise_gain = 0.01
num_sensors = 1
delta_seconds = DELTA_T
smooth_method = None
transform = StandardScaler()
dataset_name = "Recurrent0000"

# Creating and saving datasets
dataset = create_synthetic_dataset_recurrent(
    max_timesteps=max_timesteps,
    num_worms=num_worms,
    num_signals=num_signals,
    num_named_neurons=num_named_neurons,
    add_noise=add_noise,
    noise_std=noise_std,
    random_walk=random_walk,
    input_signal_gain=input_signal_gain,
    intrinsic_noise_gain=intrinsic_noise_gain,
    num_sensors=num_sensors,
    delta_seconds=delta_seconds,
    smooth_method=smooth_method,
    transform=transform,
    dataset_name=dataset_name,
)

# # Save the dataset
# save_synthetic_dataset(f"processed/neural/{dataset_name}.pickle", dataset)

In [12]:
dataset

{'worm0': {'worm': 'worm0',
  'source_dataset': 'Recurrent0000',
  'smooth_method': None,
  'calcium_data': array([[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]]),
  'smooth_calcium_data': array([[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]]),
  'residual_calcium': array([[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]]),
  'smooth_residual_calcium': array([[0., 0., 0., ..., 0., 0., 0.],
         [0., 0.

#### Helper code for plotting 

In [None]:
def thunk():
    global num_worms, worm_idx, neuron_idx
    # Selecting a worm and all the neurons to plot
    num_worms = len(dataset)
    worm_idx = random.choice([f"worm{i}" for i in range(num_worms)])
    neuron_idx = [idx for idx in dataset[worm_idx]["slot_to_neuron"].keys()][:num_named_neurons]

    # Plotting dataset
    plot_neural_signals(
        data=dataset[worm_idx]["calcium_data"],
        time_tensor=dataset[worm_idx]["time_in_seconds"],
        neuron_idx=neuron_idx,
        yax_limit=False,
        suptitle=f"{dataset_name} - {worm_idx}",
    )

    # Visualize covariance matrix
    data = dataset[worm_idx]["calcium_data"]
    mask = dataset[worm_idx]["named_neurons_mask"]
    neurons = sorted(dataset[worm_idx]["named_neuron_to_slot"])

    # X = data[:, mask].numpy()
    # n = X.shape[0]
    # # centering the data here is redundant if StandardScaler was used when creating the dataset
    # X_bar = X - np.mean(X, axis=0, keepdims=True)
    # cov = 1 / (n - 1) * X_bar.T @ X_bar

    # plt.figure()
    # ax = sns.heatmap(cov, cmap="coolwarm", xticklabels=neurons, yticklabels=neurons)
    # ax.set_title(f"Covariance matrix : {dataset_name}, {worm_idx}")
    # plt.show()

    # # Plotting 3D trajectory
    # plot_3d_trajectory(X, axis_labels=tuple(neurons), title=f"{dataset_name} neural trajectory")

    ### DEBUG ###
    V = data.numpy()
    n = V.shape[0]
    # Centering the data here is redundant if StandardScaler was used when creating the dataset
    V_bar = V - np.mean(V, axis=0, keepdims=True)
    cov = 1 / (n - 1) * V_bar.T @ V_bar
    X = V[:, mask.numpy()]

    heat_mask = (mask.unsqueeze(1).numpy() * 1) @ (mask.unsqueeze(0).numpy() * 1)
    heat_mask = ~heat_mask.astype(bool)
    cmap = sns.color_palette("coolwarm", as_cmap=True)
    cmap.set_bad(color="black")  # Set the color for NaN values

    plt.figure(figsize=(12, 12))
    ax = sns.heatmap(
        cov, cmap=cmap, mask=heat_mask, xticklabels=NEURON_LABELS, yticklabels=NEURON_LABELS
    )
    # Adjust the font size of x and y tick labels
    ax.set_xticklabels(ax.get_xticklabels(), fontsize=4)
    ax.set_yticklabels(ax.get_yticklabels(), fontsize=4)
    ax.set_title(f"Covariance matrix : {dataset_name}, {worm_idx}")
    plt.show()

    # Plotting 3D trajectory
    plot_3d_trajectory(X, axis_labels=tuple(neurons), title=f"{dataset_name} neural trajectory")
    ### DEBUG ###

In [None]:
# Plotting the dataset
thunk()

##### Analysis of recurrent network system recovery from sparse data

Let us rearrange the discrete-time update equation from earlier:

$$ \mathbf{x}(t + \Delta t) =  \mathbf{M} f\left(\mathbf{x}(t)\right) + \mathbf{b}(t) $$

in a way that makes the "push-forward" function more obvious and makes the matrix we wish to recover from the data more explicit. 

**N.B.**: What we refer to as the "push-forward" is the function that maps the current state at time $t$ to the subsequent state at time $t+\Delta t$.

We will make use of the linear approximation $f(x) \approx x$ for small $x$:

$$ \mathbf{M} \mathbf{x}(t) + \mathbf{b}(t) = \mathbf{x}(t + \Delta t) $$

Right-multiplying both sides by $\mathbf{x}(t)^\intercal$ gives:

$$ \mathbf{M} \mathbf{x}(t)\mathbf{x}(t)^\intercal + \mathbf{b}(t)\mathbf{x}(t)^\intercal = \mathbf{x}(t + \Delta t)\mathbf{x}(t)^\intercal $$

Each of the outer products $\mathbf{x}(t)\mathbf{x}(t)^\intercal$, $\mathbf{b}(t)\mathbf{x}(t)^\intercal$ and $\mathbf{x}(t + \Delta t)\mathbf{x}(t)^\intercal$ has rank at most 1, so none of them is invertible when $\operatorname{dim}(\mathbf{x}) > 1$. But since we have multiple samples (a.k.a. observations or measurements) of the state $\mathbf{x}(t)$, $t \in \{0, 1, 2, \dots, T\}$ (indexing samples, so that $t + \Delta t$ denotes the next sample), we can average over the $T$ consecutive pairs to form the sample (second-moment) covariance matrices:

$$
\begin{aligned}
\hat{\mathbf{\Sigma}}_x &= \frac{1}{T} \sum_{t=0}^{T-1}\mathbf{x}(t)\mathbf{x}(t)^\intercal \\
\hat{\mathbf{\Sigma}}_{bx} &= \frac{1}{T} \sum_{t=0}^{T-1} \mathbf{b}(t)\mathbf{x}(t)^\intercal \\
\hat{\mathbf{\Sigma}}_{{\Delta t}x} &= \frac{1}{T} \sum_{t=0}^{T-1}\mathbf{x}(t + \Delta t)\mathbf{x}(t)^\intercal
\end{aligned}
$$

Only $\hat{\mathbf{\Sigma}}_x$ needs to be inverted. It is full rank, and thus invertible, if at least $\operatorname{dim}(\mathbf{x})$ linearly independent states are sampled. If sufficiently rich dynamics are present, the probability of this being true $\to 1$ as $T \to \infty$. Averaging the right-multiplied equation over $t$ then gives:

$$
\begin{aligned}
\hat{\mathbf{M}}\hat{\mathbf{\Sigma}}_x + \hat{\mathbf{\Sigma}}_{bx} &= \hat{\mathbf{\Sigma}}_{{\Delta t}x} \\
\rightarrow \hat{\mathbf{M}} &= \left( \hat{\mathbf{\Sigma}}_{{\Delta t}x} - \hat{\mathbf{\Sigma}}_{bx} \right) \left( \hat{\mathbf{\Sigma}}_x \right)^{-1}
\end{aligned}
$$
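
The rank condition on $\hat{\mathbf{\Sigma}}_x$ can be illustrated with random, linearly independent states (Gaussian samples, chosen only for illustration): a single outer product has rank 1, the average over many samples is full rank, and fewer than $\operatorname{dim}(\mathbf{x})$ samples can never give a full-rank $\hat{\mathbf{\Sigma}}_x$.

```python
import numpy as np

rng = np.random.default_rng(3)
n, T = 5, 200
X = rng.normal(size=(T, n))          # row t is a sampled state x(t)
tol = 1e-8                           # explicit tolerance for the numerical rank

rank_single = np.linalg.matrix_rank(np.outer(X[0], X[0]), tol=tol)   # one x x^T
sigma_x = X.T @ X / T                                                # (1/T) sum_t x(t) x(t)^T
rank_sigma = np.linalg.matrix_rank(sigma_x, tol=tol)
rank_few = np.linalg.matrix_rank(X[: n - 1].T @ X[: n - 1], tol=tol) # only n - 1 samples

print(rank_single, rank_sigma, rank_few)
```
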

We can think of the external input $\mathbf{b}$ as the environment, which is impossible to measure perfectly and often not measured at all. However, we can simplify the problem of recovering the matrix $\mathbf{M}$ by assuming either:
 1. no external input, $\mathbf{b}(t) = \mathbf{0}$, or
 2. the input and state are uncorrelated, $\hat{\mathbf{\Sigma}}_{bx} = \mathbf{0}$.

**N.B.** Neither assumption is fully valid in the real world, but the second is slightly more reasonable: it essentially says that the environment is independent of the network state.

With the above assumptions, the equation for recovering the matrix $\mathbf{M}$ from observations/measurement data simplifies to:
$$
\hat{\mathbf{M}} = \hat{\mathbf{\Sigma}}_{{\Delta t} x} \left( \hat{\mathbf{\Sigma}}_x \right)^{-1}
$$
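A minimal sketch of this simplified estimator on simulated data, assuming white-noise inputs that the estimator never sees; like the notebook code, it uses `np.linalg.pinv` in place of the inverse. The connectivity, the sizes and the random-guess baseline are illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)
n, T = 10, 20000  # hypothetical sizes
Q, _ = np.linalg.qr(rng.standard_normal((n, n)))
M = 0.8 * Q  # hypothetical stable connectivity (spectral radius 0.8)

# Simulate x(t+1) = M x(t) + b(t); the white-noise input b(t) is never shown to the estimator
X = np.zeros((T + 1, n))
for t in range(T):
    X[t + 1] = M @ X[t] + rng.standard_normal(n)

sigma_x = X[:-1].T @ X[:-1] / T
sigma_dtx = X[1:].T @ X[:-1] / T

# Simplified estimator (assumes Sigma_bx = 0)
M_hat = sigma_dtx @ np.linalg.pinv(sigma_x)

rel_err = np.linalg.norm(M - M_hat) / np.linalg.norm(M)
chance = np.linalg.norm(M - rng.standard_normal(M.shape)) / np.linalg.norm(M)
print(f"relative error: {rel_err:.3f} (random N(0, 1) guess: {chance:.3f})")
```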


In [None]:
### DEBUG ###
s = dataset[worm_idx]["extra_info"]["sensorium"]  # indices of sensory neurons, (num_sensors,)
print(f"sensory neurons: {np.array(NEURON_LABELS)[s]}\n")

# Sort the complement so that not_s[0] below is deterministic (set order is arbitrary)
not_s = sorted(set(range(num_signals)) - set(s))  # (num_signals - num_sensors,)
print(f"non-sensory neurons: {np.array(NEURON_LABELS)[not_s]}\n")

b = dataset[worm_idx]["extra_info"]["input_matrix"]  # (T, num_signals)
print(f"input matrix shape: {b.shape}\n")

X = dataset[worm_idx]["calcium_data"].numpy()  # (T, num_signals)
print(f"measurement data shape: {X.shape}\n")

if len(s) > 0:
    plt.figure()
    plt.plot(b[:, s[0]])
    plt.title(f"input at {np.array(NEURON_LABELS)[s[0]]}")
    plt.show()

    plt.figure()
    plt.plot(X[:, s[0]])
    plt.title(f"measurement at {np.array(NEURON_LABELS)[s[0]]}")
    plt.show()

if len(not_s) > 0:
    plt.figure()
    plt.plot(b[:, not_s[0]])
    plt.title(f"input at {np.array(NEURON_LABELS)[not_s[0]]}")
    plt.show()

    plt.figure()
    plt.plot(X[:, not_s[0]])
    plt.title(f"measurement at {np.array(NEURON_LABELS)[not_s[0]]}")
    plt.show()
### DEBUG ###

In [None]:
cov_x = np.zeros((num_signals, num_signals))
cov_dtx = np.zeros((num_signals, num_signals))
cov_bx = np.zeros((num_signals, num_signals))  # oracle knowledge
total_mask = np.zeros((num_signals, num_signals))
M = None
cap = 200  # maximum number of worms to pool (only binds if cap < num_worms)
for idx in range(min(cap, num_worms)):
    # Get current worm ID
    worm_idx = f"worm{idx}"
    # Check that the "connectome" is the same for all worms
    if M is None:
        M = dataset[worm_idx]["extra_info"]["connection_weights"]
        A = dataset[worm_idx]["extra_info"]["adjacency_matrix"]
    else:
        M_ = dataset[worm_idx]["extra_info"]["connection_weights"]
        assert np.allclose(M, M_), "Inconsistent connection weights!"
        A_ = dataset[worm_idx]["extra_info"]["adjacency_matrix"]
        assert np.allclose(A, A_), "Inconsistent adjacency matrix!"
    # Compute the matrices needed to estimate the connectivity matrix
    mask = dataset[worm_idx]["named_neurons_mask"].numpy().reshape(1, -1)  # (1, num_signals)
    X = dataset[worm_idx]["calcium_data"].numpy()  # (T, num_signals)
    # Having the inputs requires an oracle or a perfect measurement of the environment
    b = dataset[worm_idx]["extra_info"]["input_matrix"]  # (T, num_signals)
    T = X.shape[0]
    # S[i, j] is nonzero iff neurons i and j were both recorded in this worm
    S = mask.T @ mask  # (num_signals, num_signals)
    total_mask += S
    # # (num_signals, T) x (T, num_signals) -> (num_signals, num_signals)
    # cov_x += ( (X[:-1].T @ X[:-1]) / T ) * S
    # cov_dtx += ( (X[1:].T @ X[:-1]) / T ) * S
    # cov_bx += ( (b[:-1].T @ X[:-1] ) / T ) * S
    ### DEBUG ###
    # Idea: subsample evenly spaced time points to reduce the dependence between consecutive state samples
    inds = np.unique(np.linspace(1, T - 1, 2 * num_signals, dtype=int))
    cov_x += ((X[inds - 1].T @ X[inds - 1]) / len(inds)) * S
    cov_dtx += ((X[inds].T @ X[inds - 1]) / len(inds)) * S
    cov_bx += ((b[inds - 1].T @ X[inds - 1]) / len(inds)) * S
    ### DEBUG ###
# Avoid division by zero for neuron pairs that were never recorded together
total_mask = np.clip(total_mask, a_min=1, a_max=None)

# Average each covariance entry over the worms in which that neuron pair was recorded
cov_x = np.divide(cov_x, total_mask)
cov_dtx = np.divide(cov_dtx, total_mask)
cov_bx = np.divide(cov_bx, total_mask)

##################################################################################################################################

# Estimator of M matrix
approx_M = cov_dtx @ np.linalg.pinv(cov_x)

# Plot figures
plt.figure()
ax = sns.heatmap(M, cmap="coolwarm")
ax.set_title(f"Ground truth connectivity matrix : {dataset_name}")
plt.show()

plt.figure()
ax = sns.heatmap(approx_M, cmap="coolwarm")
ax.set_title(f"Approximate connectivity matrix : {dataset_name}")
plt.show()

##################################################################################################################################

print(f"total_mask: {total_mask}\n")
print(f"num_worms: {num_worms}\n")
print("number of neurons observed in at most one worm:", (np.diag(total_mask) <= 1).sum(), end="\n\n")

print("cov_x symmetric:", scp.linalg.issymmetric(cov_x))
print("cov_x rank:", np.linalg.matrix_rank(cov_x))
print("cov_x determinant:", np.linalg.det(cov_x))
print("~" * 50)
print("cov_dtx symmetric:", scp.linalg.issymmetric(cov_dtx))
print("cov_dtx rank:", np.linalg.matrix_rank(cov_dtx))
print("cov_dtx determinant:", np.linalg.det(cov_dtx), end="\n\n")

print(
    f"True M diagonal (min, max): {np.diag(M).min(), np.diag(M).max()}\n"
    f"Approximate M diagonal (min, max): {np.diag(approx_M).min(), np.diag(approx_M).max()}\n"
)
print(
    f"True M eigenvalues (min, max): {np.linalg.eigvals(M).real.min(), np.linalg.eigvals(M).real.max()}\n"
    f"Approximate M eigenvalues (min, max): {np.linalg.eigvals(approx_M).real.min(), np.linalg.eigvals(approx_M).real.max()}\n",
    end="\n\n",
)

# ### DEBUG ###
# # Estimate of M matrix given oracle knowledge of the input b
# # TODO: WHY IS THIS WORSE THAN THE NON-ORACLE APPROXIMATION?!
# # HYPOTHESIS: It may be better at non-sensorium neurons.
# oracle_M = (cov_dtx - cov_bx) @ np.linalg.pinv(cov_x)
# plt.figure()
# ax = sns.heatmap(oracle_M, cmap="coolwarm")
# ax.set_title(f"Oracle M matrix : {dataset_name}")
# plt.show()
# print(f"Oracle M diagonal (min, max): {np.diag(oracle_M).min(), np.diag(oracle_M).max()}\n")
# print(
#     f"Oracle M eigenvalues (min, max): {np.linalg.eigvals(oracle_M).real.min(), np.linalg.eigvals(oracle_M).real.max()}\n",
#     end="\n\n",
# )
# print(
#     "estimate oracle distance (M - oracle_M):", np.linalg.norm(M - oracle_M, ord="fro"), end="\n\n"
# )
# ### DEBUG ###

##################################################################################################################################

print("estimate full distance (M - approx_M):", np.linalg.norm(M - approx_M, ord="fro"))

print(
    "(distribution) chance distance:", np.linalg.norm(M - np.random.randn(*M.shape), ord="fro")
)  # if all you knew was that the weights were ~N(0,1)
print(
    "(adjacency + distribution) chance distance:",
    np.linalg.norm(M - A * np.random.randn(*M.shape), ord="fro"),
)  # if you additionally knew the adjacency matrix
print(
    "(sign + adjacency + distribution) chance distance:",
    np.linalg.norm(M - A * np.sign(M) * np.abs(np.random.randn(*M.shape)), ord="fro"),
    end="\n\n",
)  # if you additionally knew the signs of the weights

##################################################################################################################################
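
The cell above pools masked second moments across worms that each record only a subset of neurons. The following self-contained sketch applies the same pooling to synthetic data, assuming every worm shares one connectivity matrix, records a random subset of `n_rec` neurons, and reports unrecorded neurons as zero; all sizes are hypothetical:

```python
import numpy as np

rng = np.random.default_rng(0)
n, T, num_worms, n_rec = 10, 2000, 100, 7  # hypothetical sizes

# Shared "connectome": a random orthogonal matrix scaled to spectral radius 0.5
Q, _ = np.linalg.qr(rng.standard_normal((n, n)))
M = 0.5 * Q

# Simulate all worms at once: X[t] has shape (num_worms, n); x(t+1) = M x(t) + b(t)
X = np.zeros((T + 1, num_worms, n))
for t in range(T):
    X[t + 1] = X[t] @ M.T + rng.standard_normal((num_worms, n))

cov_x = np.zeros((n, n))
cov_dtx = np.zeros((n, n))
total_mask = np.zeros((n, n))
for w in range(num_worms):
    # Each worm records a random subset of n_rec neurons; the rest are zeroed out
    mask = np.zeros((1, n))
    mask[0, rng.choice(n, size=n_rec, replace=False)] = 1.0
    S = mask.T @ mask  # S[i, j] = 1 iff neurons i and j are both recorded
    Xw = X[:, w, :] * mask
    cov_x += (Xw[:-1].T @ Xw[:-1]) / T * S
    cov_dtx += (Xw[1:].T @ Xw[:-1]) / T * S
    total_mask += S

print("fewest worms co-recording a neuron pair:", int(total_mask.min()))

# Average each entry over the worms in which that neuron pair was co-recorded
total_mask = np.clip(total_mask, a_min=1, a_max=None)
cov_x /= total_mask
cov_dtx /= total_mask

approx_M = cov_dtx @ np.linalg.pinv(cov_x)
rel_err = np.linalg.norm(M - approx_M) / np.linalg.norm(M)
print("relative Frobenius error:", rel_err)
```

Each pooled entry is averaged only over the worms in which that neuron pair was co-recorded, so every pair has to be co-recorded in at least one worm for the corresponding entry to carry any information.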

**Key takeaways:**

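* Recovering the connectivity matrix is impossible without input noise, even though the estimator $\hat{\mathbf{M}} = \hat{\mathbf{\Sigma}}_{{\Delta t} x} \left( \hat{\mathbf{\Sigma}}_x \right)^{-1}$ ignores the input because we don't know it! The noise is what drives the state through enough linearly independent directions for $\hat{\mathbf{\Sigma}}_x$ to be invertible.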

* Minimal connectedness (i.e. a sparse network) seems to be essential for estimating the connectivity matrix; densely connected networks are nearly impossible to estimate. We don't yet understand why mathematically, but it is likely because of the multiple pathways that influence each node's activity.
    * Update: It is still possible to estimate densely connected networks, but noise then needs to be injected independently at almost every node.
    * So it seems that the denser (more connected) the network is, the larger the fraction of the network that must belong to the sensorium, so that noise can be injected there.

* It seems more important to record more neurons per worm from fewer worms than to record only a few neurons from each of many worms. So optimizing the number of neurons recorded per animal matters more than improving the throughput of animals.

* A critical threshold of sensory neurons (i.e. nodes allowed to receive external inputs) needs to be passed ($\approx 200$ out of the $300$ at the maximum sparsity compatible with connectedness) for the connection weights to be recoverable. It remains to be seen whether the external inputs $\mathbf{b}$ must be Gaussian white noise or whether they can be a time-varying smooth signal.
    * System recovery works optimally for a temporally independent (white) Gaussian noise input signal.
    * System recovery fails for a high-frequency time-varying input signal, even when the frequency is on the order of the number of measurement time steps.
    * Adding noise to a high-frequency time-varying input signal improves system recovery.
    * System recovery fails for a low-frequency time-varying input signal.
    * Adding noise to a low-frequency time-varying input signal allows some system recovery.

* So all you really need is a way to inject a significant amount of noise into your system, and you can recover the weights. The signal-to-noise ratio of the external input needs to be quite small for us to be able to estimate the connectivity weights.
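
The input-type observations above can be probed with a sketch like the following, which compares a white-noise input, a low-frequency sinusoid, and the sum of the two on a hypothetical 10-neuron linear system. The sinusoid period, the sizes and the connectivity are assumptions; this is not the simulation that produced the takeaways:

```python
import numpy as np

def recovery_error(M, inputs):
    """Simulate x(t+1) = M x(t) + b(t) for inputs b of shape (T, n) and return the relative
    Frobenius error of the estimator M_hat = Sigma_dtx @ pinv(Sigma_x), which ignores b."""
    T, n = inputs.shape
    X = np.zeros((T + 1, n))
    for t in range(T):
        X[t + 1] = M @ X[t] + inputs[t]
    sigma_x = X[:-1].T @ X[:-1] / T
    sigma_dtx = X[1:].T @ X[:-1] / T
    M_hat = sigma_dtx @ np.linalg.pinv(sigma_x)
    return np.linalg.norm(M - M_hat) / np.linalg.norm(M)

rng = np.random.default_rng(0)
n, T = 10, 20000  # hypothetical sizes
Q, _ = np.linalg.qr(rng.standard_normal((n, n)))
M = 0.8 * Q  # hypothetical stable connectivity

steps = np.arange(T)[:, None]  # (T, 1)
phases = rng.uniform(0, 2 * np.pi, size=(1, n))  # one random phase per neuron
smooth = np.sin(2 * np.pi * steps / 1000 + phases)  # low-frequency input, period 1000 steps
white = rng.standard_normal((T, n))  # temporally independent Gaussian input

err_white = recovery_error(M, white)
err_smooth = recovery_error(M, smooth)
err_smooth_noisy = recovery_error(M, smooth + white)
print(f"white: {err_white:.3f}, smooth: {err_smooth:.3f}, smooth + noise: {err_smooth_noisy:.3f}")
```

Intuitively, a single sinusoid drives the state into a (nearly) two-dimensional subspace, so $\hat{\mathbf{\Sigma}}_x$ is badly conditioned, and the smooth part of the input is correlated with the state, which violates $\hat{\mathbf{\Sigma}}_{bx} = \mathbf{0}$.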
    
    