# Neural space-time model example on structured illumination microscopy (SIM) data
JupyterLab demo for "Neural space-time model for dynamic multi-shot imaging" by Ruiming Cao et al. (2024).

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import os
import numpy as np
import jax
import jax.numpy as jnp
from flax import linen as nn
import matplotlib.pyplot as plt

from ipywidgets import interact, IntSlider

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import calcil as cc

from nstm import sim3d_utils
from nstm import sim3d_flow
from nstm import spacetime
from nstm import utils

## Load data

In [None]:
# Read the three arrays used by this demo from the .npz archive next to the
# notebook; the `with` block closes the archive once they are loaded.
with np.load('./sim_data_beads.npz') as archive:
    timestamp_phase, img, OTF = (
        archive[key] for key in ('timestamp_phase', 'img', 'OTF'))

### Visualize raw images

In [None]:
# Interactive viewer for the raw frames in `img`, indexed as img[ori, phase, 0].
f, ax = plt.subplots(1, 1, figsize=(10, 4))
frame = []


def init():
    """Show the first raw frame and store its image artist in `frame`."""
    frame.append(ax.imshow(img[0, 0, 0], cmap='gray'))
    ax.axis('off')
    f.tight_layout()


init()


def _bind_raw_viewer(artist, canvas, stack):
    """Build the slider callback for this viewer.

    The callback closes over its own artist, canvas and image stack instead of
    looking up the notebook globals `frame` / `f` when a slider moves. Those
    names are rebound by the reconstruction viewer further down, which would
    otherwise break this widget as soon as that cell has run.

    Args:
        artist: image artist returned by `ax.imshow` for this figure.
        canvas: canvas of the figure that owns `artist`.
        stack: array indexed as stack[ori, phase, 0] to obtain a 2D frame.

    Returns:
        A function `updateFrame(ori, phase)` suitable for `interact`.
    """
    def updateFrame(ori, phase):
        """Display the raw frame at first-axis index `ori`, second-axis index `phase`."""
        artist.set_data(stack[ori, phase, 0])
        canvas.draw_idle()  # request a repaint even if auto-redraw is off

    return updateFrame


updateFrame = _bind_raw_viewer(frame[0], f.canvas, img)

# Slider ranges follow the two leading axes of `img` rather than fixed counts.
interact(updateFrame,
         ori=IntSlider(min=0, max=img.shape[0] - 1, step=1, value=0),
         phase=IntSlider(min=0, max=img.shape[1] - 1, step=1, value=0))

## Define system parameters

In [None]:
# Scalar acquisition settings. Units are not stated in this notebook; the
# lengths look like micrometres — TODO confirm against the acquisition metadata.
zoomfact = 1.6           # scales the lateral grid size up and the pixel size down (see below)
ps = 0.040625            # pixel size handed to the OTF geometry
dz = 0.15                # axial step of the reconstruction volume
dz_otf = 0.1             # axial step of the OTF volume
wavelength = 0.515       # presumably the emission wavelength — confirm
wavelength_exc = 0.488   # presumably the excitation wavelength — confirm
na = 1.28
RI_medium = 1.33

# Grid sizes: a single z plane with a square lateral grid of int(200 * zoomfact).
dim_zyx = (1,) + (int(200 * zoomfact),) * 2
padding_zyx = (0,) * 3
dim_otf_zx = (101, 256)

# Reconstruction grid: the pixel size is divided by the same zoom factor that
# enlarged the lateral dimensions.
optical_param = utils.SystemParameters3D(
    dim_zyx, wavelength, wavelength_exc, na, ps / zoomfact, dz, RI_medium, padding_zyx)

# OTF grid: the lateral size from dim_otf_zx is used for both y and x, the
# pixel size is left unscaled and no padding is applied.
otf_param = utils.SystemParameters3D(
    (dim_otf_zx[0],) + (dim_otf_zx[1],) * 2,
    wavelength, wavelength_exc, na, ps, dz_otf, RI_medium, (0,) * 3)

# Three entries each — presumably one per illumination pattern orientation; confirm.
line_spacing = [0.419490, 0.419520, 0.419330]
k0angles = [1.782, 2.897, -2.379]

## Set up OTF

In [None]:
# Band separation is applied independently to each entry along the first axis
# of `img`, and the results are collected back into one array.
bandplus_img = np.array(list(map(
    lambda stack: sim3d_utils.separate_bands(stack, out_positive_bands=True), img)))

# Tuple passed as crop_boundary_zyx: the first entry is 2 only when the volume
# has more than 10 planes (0 otherwise); the other two entries are always 20.
n_planes = optical_param.dim_zyx[0]
crop_boundary = (2 if n_planes > 10 else 0, 20, 20)

# Only the first `n_planes` entries of the third axis are handed to get_otf.
sim_param, phase, amp = sim3d_utils.get_otf(
    bandplus_img[:, :, :n_planes], OTF, optical_param, otf_param,
    ndirs=3, nphases=5, k0angles=k0angles, line_spacing=line_spacing,
    crop_boundary_zyx=crop_boundary, noisy=True, notch=False)

otf_valid_region, otf_valid_region_rfft = sim3d_utils.otf_support_mask(
    optical_param, otf_param, sim_param, OTF[0])

## Reconstruction via neural space-time model

### Define NSTM parameters

In [None]:
# Coarsest / finest hash-grid resolutions. The object grids have three entries,
# the motion grids four; the object values are doubled when passed on below.
object_hash_base, object_hash_fine = np.array([1, 25, 25]), np.array([1, 250, 250])
motion_hash_base, motion_hash_fine = np.array([1, 1, 1, 3]), np.array([1, 1, 1, 15])

vol = optical_param.dim_zyx

# Object bounding box: extends half a volume width beyond the volume on each
# side in y and x, and from 0 to twice the depth in z.
object_box = (
    np.array([0, -vol[1] / 2, -vol[2] / 2]),
    np.array([vol[0] * 2, vol[1] * 1.5, vol[2] * 1.5]),
)
hash_param = spacetime.HashParameters(
    bounding_box=object_box,
    n_levels=8,
    n_features_per_level=2,
    log2_hashmap_size=16,
    base_resolution=object_hash_base * 2,
    finest_resolution=object_hash_fine * 2)

# Motion bounding box: the volume itself in y and x (twice the depth in z),
# plus a fourth coordinate spanning [-1, 1].
motion_box = (
    np.array([0, 0, 0, -1]),
    np.array([vol[0] * 2, vol[1], vol[2], 1]),
)
hash_param_motion_spacetime = spacetime.HashParameters(
    bounding_box=motion_box,
    n_levels=8,
    n_features_per_level=2,
    log2_hashmap_size=16,
    base_resolution=motion_hash_base,
    finest_resolution=motion_hash_fine)

# Both networks are two layers deep; the object MLP is wider (128 vs 32) and
# uses GELU where the motion MLP uses ELU.
object_mlp_param = spacetime.MLPParameters(net_depth=2, net_width=128, net_activation=nn.gelu)
motion_mlp_param = spacetime.MLPParameters(net_depth=2, net_width=32, net_activation=nn.elu)

space_time_param = spacetime.SpaceTimeParameters(
    motion_mlp_param=motion_mlp_param,
    object_mlp_param=object_mlp_param,
    motion_embedding='hash_combined',
    motion_embedding_param=hash_param_motion_spacetime,
    object_embedding='hash',
    object_embedding_param=hash_param,
    out_activation=lambda x: x)  # identity: the network output is used as-is


### Data loader

In [None]:
# (3, 5) index grids: ind_k0angle[i, j] == i and ind_phases[i, j] == j.
ind_k0angle, ind_phases = np.meshgrid(np.arange(3), np.arange(5), indexing='ij')
# Every one of the 3 x 5 frames gets the same z offset (half the volume depth
# plus the z padding); the y and x offsets are zero.
z_offset = np.full(
    (3, 5), optical_param.dim_zyx[0] // 2 + optical_param.padding_zyx[0], dtype=np.float64)
zyx_offset = np.zeros(z_offset.shape + (3,), dtype=z_offset.dtype)
zyx_offset[..., 0] = z_offset

# Flatten the leading axes so that each entry of every array describes one frame.
per_frame_inputs = {
    'img': img.reshape((-1,) + optical_param.dim_zyx),
    't': timestamp_phase.reshape(-1),
    'ind_phase': ind_phases.reshape(-1),
    'ind_k0angle': ind_k0angle.reshape(-1),
    'zyx_offset': zyx_offset.reshape((-1, 3)),
}
data_loader = cc.data_utils.loader_from_numpy(per_frame_inputs, prefix_dim=(5,), seed=85471)

# Draw one epoch from the loader to learn how many batches it holds and to get
# a sample batch for model initialisation.
sample_epoch_input = next(data_loader)
num_batches_per_epoch = len(sample_epoch_input)
sample_input_dict = sample_epoch_input[0]

### Initialize NSTM model

In [None]:
# Training length, and the fraction of it passed to the model as `annealed_epoch`.
num_epoch, annealed_rate = 2000, 0.8

model = sim3d_flow.SIM3DSpacetime(
    sim_param,
    space_time_param,
    optical_param,
    annealed_epoch=num_epoch * annealed_rate,
)

# Fixed PRNG key so that parameter initialisation is repeatable.
rng = jax.random.PRNGKey(98321)
variables = model.init(rng, input_dict=sample_input_dict)

### Define loss function

In [None]:
# The objective has a single term, so `total_loss` is the same object as `l2_loss`.
total_loss = l2_loss = cc.loss.Loss(sim3d_flow.gen_loss_l2_stack(margin=2), 'l2')

### Run reconstruction

In [None]:
save_path = './checkpoint/SIM_beads/'

# Iteration settings: `checkpoint_every` equals the total epoch count and a
# single checkpoint is kept; output and logging intervals are 100.
recon_param = cc.reconstruction.ReconIterParameters(
    save_dir=save_path, n_epoch=num_epoch, keep_checkpoints=1,
    checkpoint_every=num_epoch, output_every=100, log_every=100)


def _adam_exp_decay(lr):
    """Return Adam settings with an 'exponential' schedule for initial rate `lr`.

    The optimiser and schedule keyword dictionaries are built anew on every
    call, so the returned parameter objects never share a dict.
    NOTE(review): the schedule kwargs look like an optax-style exponential
    decay (factor 0.1 over all training steps) — confirm in calcil.
    """
    return cc.reconstruction.ReconVarParameters(
        lr=lr, opt='adam', opt_kwargs={'b1': 0.9, 'b2': 0.99, 'eps': 1e-15},
        schedule='exponential',
        schedule_kwargs={'transition_steps': num_epoch * num_batches_per_epoch,
                         'decay_rate': 0.1, 'transition_begin': 0},
        update_every=1)


no_update_params = cc.reconstruction.ReconVarParameters(lr=0)
object_mlp_params = _adam_exp_decay(1e-3)
motion_mlp_params = _adam_exp_decay(1e-5)

# Each embedding shares the settings object of its MLP; 'fluo_forward' gets a
# learning rate of 0.
spacetime_var_params = {
    'motion_mlp': motion_mlp_params,
    'object_mlp': object_mlp_params,
    'motion_embedding': motion_mlp_params,
    'object_embedding': object_mlp_params,
}
var_params = {'params': {'spacetime': spacetime_var_params,
                         'fluo_forward': no_update_params}}

recon_variables, recon = cc.reconstruction.reconstruct_multivars_sgd(
    model.apply, variables, var_params, data_loader, total_loss, recon_param, None, None)

In [None]:
def _object_at(module, t):
    """Evaluate the model's `spacetime` submodule at time array `t`.

    The spatial offset is (z_offset[0, 0], 0, 0). From the result, the first
    entry of the leading axis and the first entry of the trailing axis are kept.
    """
    return module.spacetime(t, np.array([[z_offset[0, 0], 0, 0]]))[0, ..., 0]


# One evaluation per acquisition timestamp. Iterating over the flattened
# timestamps keeps the number of frames tied to the data instead of a
# hard-coded count.
recon_t = np.array([
    model.apply(recon_variables, np.array([t]), method=_object_at)
    for t in timestamp_phase.reshape(-1)])

# Multiply each frame's 3D spectrum (last three axes) by `otf_valid_region`,
# transform back, and keep the real part as float32.
recon_t = np.fft.ifftn(
    otf_valid_region[np.newaxis, :, :, :] * np.fft.fftn(recon_t, axes=(-3, -2, -1)),
    axes=(-3, -2, -1)).real.astype(np.float32)

### Visualize the reconstruction

In [None]:
# Interactive viewer for the reconstruction, showing recon_t[t, 0] for the
# selected index t. The display range is fixed to the 1st-99.9th percentile of
# the whole stack so that brightness stays comparable between frames.
f, ax = plt.subplots(figsize=(6, 6))
frame = ax.imshow(recon_t[0, 0],
                  clim=(np.percentile(recon_t, 1), np.percentile(recon_t, 99.9)),
                  cmap='gray')
ax.axis('off')
f.tight_layout()


def _bind_recon_viewer(artist, canvas, volumes):
    """Build the slider callback for this viewer.

    The callback closes over its own artist, canvas and data instead of looking
    up the notebook global `frame` when the slider moves. The raw-image viewer
    earlier in the notebook binds `frame` to a list, so re-running that cell
    would otherwise break this widget.

    Args:
        artist: image artist returned by `ax.imshow` for this figure.
        canvas: canvas of the figure that owns `artist`.
        volumes: array indexed as volumes[t, 0] to obtain a 2D frame.

    Returns:
        A function `update(t)` suitable for `interact`.
    """
    def update(t):
        """Display the reconstructed frame at index `t`."""
        artist.set_data(volumes[t, 0])
        canvas.draw_idle()  # request a repaint even if auto-redraw is off

    return update


update = _bind_recon_viewer(frame, f.canvas, recon_t)

# The slider range follows the number of reconstructed frames.
interact(update, t=IntSlider(min=0, max=recon_t.shape[0] - 1, step=1, value=0))

### Turning off the motion update
To compare against a static (motion-free) reconstruction, turn off the motion update: set the learning rate in `motion_mlp_params` to 0, set every timepoint value (the `'t'` entry) passed to `data_loader` to 0, and then re-run the reconstruction.

## Reference