In [None]:
import os
from collections.abc import Sequence
import matplotlib.pyplot as plt
import time
from IPython.display import clear_output

import flax.jax_utils as flax_utils
import flax.linen as nn
import grain.python as grain
import jax
import numpy as np
from absl import logging
from connectomics.jax import checkpoint, training
from etils import epath
from orbax import checkpoint as ocp

import zapbench.models.util as model_util
from zapbench.ts_forecasting import heads, input_pipeline, train
from zapbench.ts_forecasting.configs import infer, linear


def _get_checkpoint_step(
    checkpoint_manager: ocp.CheckpointManager,
    selection_strategy: str,
) -> int | None:
  """Picks the checkpoint step to restore for a given selection strategy.

  'latest' simply returns the manager's most recent step. The two tracker
  strategies restore a single tracker item from the most recent checkpoint and
  return the `best_step` it recorded.

  Args:
    checkpoint_manager: Checkpoint manager to query.
    selection_strategy: One of 'early_stopping' (reads the 'early_stop' item),
      'best_val_loss' (reads the 'track_best_val_loss_step' item) or 'latest'.

  Returns:
    The selected checkpoint step.

  Raises:
    ValueError: If `selection_strategy` is not one of the values above.
  """
  if selection_strategy == 'latest':
    return checkpoint_manager.latest_step()

  # The tracker strategies differ only in which checkpoint item holds the step.
  if selection_strategy == 'early_stopping':
    tracker_item = 'early_stop'
  elif selection_strategy == 'best_val_loss':
    tracker_item = 'track_best_val_loss_step'
  else:
    raise ValueError(f'Unknown checkpoint selection: {selection_strategy}')

  restored = checkpoint.restore_checkpoint(
      checkpoint_manager,
      state={tracker_item: None},
      step=checkpoint_manager.latest_step(),
  )
  return restored[tracker_item]['best_step']


def infer_single_step(
    model: nn.Module,
    head: heads.Head,
    train_state: train.TrainState,
    data_source: grain.RandomAccessDataSource,
    idx: int,
    infer_key: jax.Array,  # pylint: disable=unused-argument
    covariates: Sequence[str] = (),
    covariates_static: jax.Array | None = None,
    with_carry: bool = False,
) -> tuple[jax.Array, jax.Array]:
  """Runs inference on the single window `data_source[idx]`.

  Every call starts from no initial carry, so windows are evaluated
  independently of each other.

  Args:
    model: Model forwarded to `train.pred_step`.
    head: Head whose `get_distribution` turns the model output into a
      predictive distribution.
    train_state: Train state forwarded to `train.pred_step`.
    data_source: Source indexed with `idx` to obtain the batch.
    idx: Index of the window to run.
    infer_key: Unused; kept for signature compatibility.
    covariates: Covariate names forwarded to `train.pred_step`. If it contains
      'covariates_static', `covariates_static` is written into the batch.
    covariates_static: Value stored under batch['covariates_static'] when
      requested by `covariates`.
    with_carry: Forwarded as `return_carry`; when True, `pred_step` returns a
      pair whose second element is the model output and whose first element
      (the carry) is discarded here.

  Returns:
    prediction: Mode of the predictive distribution.
    target: The batch's 'timeseries_output' entry.
  """
  batch = data_source[idx]
  if 'covariates_static' in covariates:
    batch['covariates_static'] = covariates_static

  out = train.pred_step(
      model,
      train_state,
      batch,
      covariates,
      initial_carry=None,  # each window starts without carry
      return_carry=with_carry,
  )

  # With return_carry=True the output is (carry, model_output); the carry is
  # not needed because windows are evaluated independently.
  model_output = out[1] if with_carry else out
  dist = head.get_distribution(model_output)

  prediction = dist.mode()
  target = batch['timeseries_output']
  return prediction, target

In [None]:
# Load the trained experiment: its saved config, the model built from it and
# the static covariates derived from that config.
exp_workdir = '/Users/s/vault/zapbench/linear/05'
exp_config = model_util.load_config(os.path.join(exp_workdir, 'config.json'))
model = model_util.model_from_config(exp_config)
covariates_static = input_pipeline.get_static_covariates(exp_config)

# The manager exposes the train state plus the two best-step trackers read by
# `_get_checkpoint_step`.
checkpoint_manager = checkpoint.get_checkpoint_manager(
    exp_workdir,
    item_names=('early_stop', 'train_state', 'track_best_val_loss_step'),
)
step = _get_checkpoint_step(checkpoint_manager, 'best_val_loss')

# Restore only the train state at the selected step, rebuild the TrainState
# from the restored fields and replicate it.
checkpointed_state = checkpoint.restore_checkpoint(
    checkpoint_manager,
    state={'train_state': None},
    step=step,
)
train_state = flax_utils.replicate(
    train.TrainState(**checkpointed_state['train_state'])
)

In [None]:
# Inference setup: linear-model config for subject 14 with the inference
# settings overlaid on top.
# NOTE(review): this config is rebuilt from `linear.get_config` instead of
# being taken from `exp_config` loaded above — confirm the two agree.
infer_config = infer.get_config()
config = linear.get_config('dataset_name=subject_14')
config.update(infer_config)

head = heads.create_head(config)
infer_source = input_pipeline.create_inference_source_with_transforms(config)

# Inference key: split off the configured seed, then folded with the restored
# checkpoint step.
rng, infer_rng = jax.random.split(training.get_rng(config.seed))
infer_key = jax.random.fold_in(key=infer_rng, data=step)

In [None]:
# Run single-window inference over the first entry of `config.infer_idx_sets`
# (the loop breaks after it) and print that set's metrics.
# Horizon-index 0 and 31 slices of the first `n_windows_to_keep` windows are
# copied to host memory for the anatomy error maps; the cap bounds memory use.
n_windows_to_keep = 500
all_predictions_0, all_targets_0 = [], []
all_predictions_32, all_targets_32 = [], []
cumulative_abs_error = []  # per-window |prediction - target| at horizon index 0
target_varibility = []  # per-window targets at horizon index 0 (name read by later cells)
for infer_idx_set in config.infer_idx_sets:
  name, idx_list = infer_idx_set['name'], infer_idx_set['idx_list']
  metrics_key = f'infer_{name}'
  infer_metrics = None
  train_state = train.merge_batch_stats(train_state)
  # Loop-invariant within an index set: unreplicate once, not once per window.
  eval_state = flax_utils.unreplicate(train_state)
  for i, idx in enumerate(idx_list):
    prediction, target = infer_single_step(
        model,
        head,
        eval_state,
        infer_source,
        idx,
        infer_key=infer_key,
        covariates=tuple(config.covariates),
        covariates_static=covariates_static,
        with_carry=config.infer_with_carry,
    )
    if i < n_windows_to_keep:
      all_predictions_0.append(np.asarray(prediction[0, 0]))
      all_targets_0.append(np.asarray(target[0, 0]))
      all_predictions_32.append(np.asarray(prediction[0, 31]))
      all_targets_32.append(np.asarray(target[0, 31]))
    cumulative_abs_error.append(np.abs(prediction[0, 0] - target[0, 0]))
    target_varibility.append(target[0, 0])

    if metrics_key in head.metrics:
      metrics_update = head.metrics[metrics_key].single_from_model_output(
          predictions=prediction, targets=target
      )
      infer_metrics = (
          metrics_update
          if infer_metrics is None
          else infer_metrics.merge(metrics_update)
      )
  # Report metrics before `break`; otherwise they are never computed.
  if infer_metrics is not None:
    infer_metrics_cpu = jax.tree.map(np.array, infer_metrics.compute())
    print(infer_metrics_cpu)
  # Only the first index set is evaluated for now.
  break


In [None]:
# Stack the per-window horizon-0 absolute errors into one array and reduce the
# stored horizon-0 targets to their variance across windows (one value per
# neuron). The isinstance guard keeps the cell idempotent: re-running it must
# not take the variance of the already-reduced vector.
cumulative_abs_error = np.asarray(cumulative_abs_error)
if isinstance(target_varibility, list):
  target_varibility = np.var(target_varibility, axis=0)

In [None]:
import matplotlib.pyplot as plt
import scienceplots
# 'science' and 'no-latex' are styles registered by the `scienceplots` import
# above; 'no-latex' presumably avoids requiring a LaTeX install for rendering.
plt.style.use(['science', 'no-latex'])

In [None]:
# Overlay prediction and target traces for every 500th neuron of the last
# window left over from the inference loop. Each figure is shown and then
# closed so the loop does not accumulate open figures.
i_neuron_list = range(0, 10000, 500)

for i_neuron in i_neuron_list:
  fig, ax = plt.subplots(figsize=(4, 1), dpi=200)
  ax.plot(prediction[0, :, i_neuron], label='prediction')
  ax.plot(target[0, :, i_neuron], label='target')
  ax.set_ylim(0, 0.3)
  ax.set_title(f'neuron {i_neuron}', fontsize=6)
  ax.legend(fontsize=4, loc='upper right')
  plt.show()
  plt.close(fig)

In [None]:
# Shape of the last window's target left over from the inference loop; the
# plots index it as target[0, :, neuron] (presumably (batch, horizon, neurons)
# — TODO confirm).
target.shape

In [None]:
def smooth_target(target, window_size=10, axis=1):
  """Moving-average smoothing of `target` along one axis, in 'valid' mode.

  Equivalent to convolving every 1-D slice along `axis` with a box kernel of
  length `window_size` in 'valid' mode, but vectorised instead of calling a
  Python function once per slice.

  Args:
    target: Array-like to smooth.
    window_size: Number of consecutive samples averaged; must lie in
      [1, target.shape[axis]].
    axis: Axis to smooth along (default 1).

  Returns:
    Floating-point array shaped like `target` except that `axis` has length
    `target.shape[axis] - window_size + 1`; element j along `axis` is the mean
    of samples j .. j + window_size - 1.

  Raises:
    ValueError: If `window_size` is outside [1, target.shape[axis]]. (With a
      kernel longer than the signal, np.convolve 'valid' silently swaps its
      operands and returns meaningless values.)
  """
  values = np.asarray(target)
  length = values.shape[axis]
  if not 1 <= window_size <= length:
    raise ValueError(f'window_size must be in [1, {length}], got {window_size}')
  # View of shape (..., length - window_size + 1, ..., window_size); no copy.
  windows = np.lib.stride_tricks.sliding_window_view(
      values, window_size, axis=axis
  )
  # Accumulate in at least float64, like the float64 box kernel used before.
  return windows.mean(axis=-1, dtype=np.result_type(values.dtype, np.float64))

# Moving average over 5 steps of the last window's target along axis 1; the
# result is shorter than `target` on that axis by window_size - 1.
target_smoothed = smooth_target(target, window_size=5)

In [None]:
# Same neurons, against the smoothed target. A 'valid' moving average is
# shorter than the raw trace, so its x positions are offset by half the lost
# length to keep each value centred on the time steps it averages.
smoothing_offset = (np.shape(target)[1] - np.shape(target_smoothed)[1]) // 2
for i_neuron in i_neuron_list:
  fig, ax = plt.subplots(figsize=(4, 1), dpi=200)
  ax.plot(prediction[0, :, i_neuron], label='prediction')
  smoothed_trace = target_smoothed[0, :, i_neuron]
  ax.plot(
      smoothing_offset + np.arange(len(smoothed_trace)),
      smoothed_trace,
      label='target (smoothed)',
  )
  ax.set_ylim(0, 0.3)
  ax.set_title(f'neuron {i_neuron}', fontsize=6)
  ax.legend(fontsize=4, loc='upper right')
  plt.show()
  plt.close(fig)

In [None]:
# Absolute error of the last window as a heat map; the array is transposed so
# axis 1 (presumably neurons — confirm) runs vertically.
abs_diff = np.abs(prediction - target)[0, :, :]
fig, ax = plt.subplots(figsize=(5, 5), dpi=200)
heatmap = ax.imshow(
    abs_diff.T,
    aspect='auto',
    cmap='magma',
    origin='lower',
    vmin=0,
    vmax=0.2,
)
fig.colorbar(heatmap, ax=ax)
ax.axis('off')
plt.show()

Show the prediction error on the reference brain anatomy

In [None]:
import h5py
import scipy.io  # explicit: plain `import scipy` does not load `scipy.io` on older SciPy

PATH = "/Users/s/vault/neural_data/janelia"
subject_id = 14

# Reference brain volume, used as the grayscale background of the anatomy maps.
reference_anat = scipy.io.loadmat(f"{PATH}/Additional_mat_files/ReferenceBrain.mat")

h5_path = f"{PATH}/subject_{subject_id:02d}"
print(h5_path)
# Read the index array eagerly so the HDF5 file is closed right away.
with h5py.File(f"{h5_path}/TimeSeries.h5", "r") as h5:
  # Presumably 1-based MATLAB indices; shifted to 0-based for NumPy indexing.
  abs_ix = (h5['absIX'][0] - 1).astype(int)

mat_path = f"{h5_path}/data_full.mat"
data_struct = scipy.io.loadmat(mat_path)['data'][0, 0]

# NOTE(review): field 8 of the MATLAB `data` struct is used as per-cell
# coordinates (columns 0/1 plotted as x/y) — confirm against the dataset docs.
all_cell_coordinates = data_struct[8]
coordinates = all_cell_coordinates[abs_ix]

plt.figure(figsize=(10, 10))
plt.imshow(np.sum(reference_anat['anat_stack_norm'], axis=-1).T, cmap='gray')
plt.scatter(coordinates[:, 0], coordinates[:, 1], c='red', s=0.2)
plt.title('Cell positions on the reference brain')
plt.axis('off')
plt.show()

In [None]:
# Absolute error of the last inference window (prediction/target left over
# from the inference loop). Its log range sets the colour limits of the
# anatomy maps below.
abs_deviation = np.abs(prediction[0]-target[0])
abs_deviation.shape

In [None]:
def _scatter_on_anatomy(ax, background, coords, values, cmap, vmin, vmax):
  """Draws one value per cell at `coords` over a grayscale `background`."""
  ax.imshow(background, cmap='gray')
  ax.scatter(
      coords[:, 0], coords[:, 1],
      c=values,
      cmap=cmap,
      s=0.1,
      vmin=vmin,
      vmax=vmax,
  )
  ax.axis('off')


# Loop-invariant inputs, computed once instead of once per panel and window.
anat_background = np.sum(reference_anat['anat_stack_norm'], axis=-1).T
with np.errstate(divide='ignore'):
  log_abs_deviation = np.log(abs_deviation)
# Zero deviations give -inf; only finite values define the shared colour range.
finite_log_dev = log_abs_deviation[np.isfinite(log_abs_deviation)]
anat_vmin, anat_vmax = finite_log_dev.min(), finite_log_dev.max()

# Every 10th of the first (up to) 500 stored windows: log prediction, log
# target and log absolute error, using the horizon-index-0 slices collected by
# the inference loop (NOTE(review): assumed to be the intended per-window data).
for i in range(0, min(500, len(all_predictions_0)), 10):
  prediction_i = np.asarray(all_predictions_0[i])
  target_i = np.asarray(all_targets_0[i])
  panels = (
      (np.log(prediction_i), 'coolwarm'),
      (np.log(target_i), 'coolwarm'),
      (np.log(np.abs(target_i - prediction_i)), 'magma'),
  )
  fig, axs = plt.subplots(3, 1, figsize=(20, 10))
  for ax, (values, cmap) in zip(axs, panels):
    _scatter_on_anatomy(
        ax, anat_background, coordinates, values, cmap, anat_vmin, anat_vmax
    )
  plt.show()
  clear_output(wait=True)
  plt.close(fig)

Absolute error at forecast steps t=1 and t=32 (horizon indices 0 and 31)

In [None]:
# Log absolute error per cell at forecast step t=1 (horizon index 0, top) and
# t=32 (horizon index 31, bottom) for every 20th of the first (up to) 100
# stored windows. Colour limits come from the log range of `abs_deviation`.
anat_background = np.sum(reference_anat['anat_stack_norm'], axis=-1).T
with np.errstate(divide='ignore'):
  log_abs_deviation = np.log(abs_deviation)
# Zero deviations give -inf; only finite values define the colour range.
finite_log_dev = log_abs_deviation[np.isfinite(log_abs_deviation)]
anat_vmin, anat_vmax = finite_log_dev.min(), finite_log_dev.max()

# Never index past the windows that were actually stored.
n_stored = min(len(all_predictions_0), len(all_predictions_32))
horizons = (
    ('t=1', all_predictions_0, all_targets_0),
    ('t=32', all_predictions_32, all_targets_32),
)
for i in range(0, min(100, n_stored), 20):
  fig, axs = plt.subplots(2, 1, figsize=(20, 10))
  for ax, (label, preds, targets) in zip(axs, horizons):
    ax.imshow(anat_background, cmap='gray')
    ax.scatter(
        coordinates[:, 0], coordinates[:, 1],
        c=np.log(np.abs(np.asarray(targets[i]) - np.asarray(preds[i]))),
        cmap='coolwarm',
        s=0.1,
        vmin=anat_vmin,
        vmax=anat_vmax,
    )
    ax.set_title(f'window {i}, {label}')
    ax.axis('off')
  plt.show()
  clear_output(wait=True)
  plt.close(fig)

Mean absolute error per cell at the first forecast step (horizon index 0), averaged over inference windows

In [None]:
# Per-cell absolute error at horizon index 0, averaged over the inference
# windows collected by the inference loop, on a log colour scale.
# np.mean works whether the per-window errors are still a list or already
# stacked into an array.
mean_abs_error_t0 = np.mean(cumulative_abs_error, axis=0)
fig, axs = plt.subplots(1, 1, figsize=(10, 5), dpi=300)
axs.imshow(np.sum(reference_anat['anat_stack_norm'], axis=-1).T, cmap='gray')
sc0 = axs.scatter(
    coordinates[:, 0], coordinates[:, 1],
    c=np.log(mean_abs_error_t0),
    cmap='Reds',
    s=0.2,
    alpha=0.5,
)
plt.colorbar(sc0, ax=axs, fraction=0.025, pad=0.04, aspect=30, shrink=1.0).ax.tick_params(labelsize=18)
axs.set_title('Mean absolute error', fontsize=25)
plt.tight_layout()
axs.axis('off');

Compare with the target variability of each cell (variance across inference windows): is the model error merely a reflection of how variable the data are in each region, i.e. is it uninformative on its own?

In [None]:
# Log target variability per cell, for comparison with the error map.
# Assumes `target_varibility` has already been reduced to one variance per
# cell (TODO confirm the stacking cell has run).
fig, axs = plt.subplots(1, 1, figsize=(10, 5), dpi=300)
axs.imshow(np.sum(reference_anat['anat_stack_norm'], axis=-1).T, cmap='gray')
sc0 = axs.scatter(
    coordinates[:, 0], coordinates[:, 1],
    c=np.log(target_varibility),
    cmap='coolwarm',
    s=0.2,
)
# Colour bar with the same layout as the error map so both scales can be read.
fig.colorbar(sc0, ax=axs, fraction=0.025, pad=0.04, aspect=30, shrink=1.0).ax.tick_params(labelsize=18)
plt.tight_layout()
axs.set_title('Target variability', fontsize=25)
axs.axis('off');

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Plot τ^(β-1) on log-log axes for each β in `betas` (currently 0.1 and 0.5).

fig, ax = plt.subplots(figsize=(10, 5), dpi=300)

betas = [0.1, 0.5]

# The τ grid does not depend on β, so it is built once outside the loop.
tau = np.linspace(0.01, 10, 1000)
for beta in betas:
  ax.plot(tau, tau**(beta - 1), label=f'β={beta}')

ax.legend()
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlabel('τ', fontsize=20)
ax.set_ylabel('τ^(β-1)', fontsize=20)
ax.set_title('τ^(β-1) for different β', fontsize=25)
ax.tick_params(labelsize=18)
plt.show()
