# Neural space-time model example on differential phase contrast (DPC) microscopy data
JupyterLab demo for "Neural space-time model for dynamic multi-shot imaging" by Ruiming Cao et al. (2024).

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import os
import numpy as np
import jax
import jax.numpy as jnp
from flax import linen as nn
import matplotlib.pyplot as plt

from ipywidgets import interact, IntSlider

# Select the GPU used by JAX. Devices are numbered in PCI bus order.
# `setdefault` keeps any device selection already exported in the environment,
# so the notebook still runs on machines where GPU index 2 does not exist.
# JAX reads these variables when its GPU backend is first initialized, so they
# must be set before the first JAX computation.
os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "2")

import calcil as cc

from nstm import dpc_utils
from nstm import dpc_flow
from nstm import spacetime
from nstm import utils
from nstm.hash_encoding import HashParameters

## Load data

In [None]:
# Load the raw DPC measurements and the illumination source patterns.
data_path = './dpc_data.npz'
with np.load(data_path) as d:
    img = d['img']  # measured images; first axis indexes the frame (img[i] is one 2-D image)
    s = d['s']      # illumination source patterns; first axis indexes the pattern

# Optical system parameters (units presumably micrometers, given a 0.525 wavelength -- confirm).
wavelength = 0.525
na = 0.25
pixel_size = 0.454

# Frame count and image size come from the data itself, so the system
# parameters (pupil grid, display extent) cannot silently disagree with it.
num_frames = img.shape[0]
dim_yx = tuple(img.shape[-2:])
# imshow extent [left, right, bottom, top] in physical units.
extent = [0, pixel_size * dim_yx[1], pixel_size * dim_yx[0], 0]

param = utils.SystemParameters(dim_yx, wavelength, na, pixel_size, RI_medium=1.0, padding_yx=(0,0))

## Visualize images

In [None]:
# Interactive viewer for the raw measurements: one slider step per captured frame.
f, ax = plt.subplots(1, 1, figsize=(10, 4))
frame = []  # holds the image artist so the slider callback can update it in place


def init():
    """Draw the first measurement with a fixed gray scale and hide the axes."""
    frame.append(ax.imshow(img[0], clim=(-1., 1.), cmap='gray'))
    ax.axis('off')
    f.tight_layout()


init()


def updateFrame(i):
    """Replace the displayed image with measurement ``i``."""
    frame[0].set_data(img[i])


# The slider range follows the number of frames in the data instead of a hard-coded 3.
interact(updateFrame, i=IntSlider(min=0, max=img.shape[0] - 1, step=1, value=0),)

## DPC transfer function

In [None]:
# Pupil and per-pattern DPC transfer functions. Hu/Hp are presumably the
# absorption and phase transfer functions (TODO confirm against dpc_utils).
pupil = cc.physics.wave_optics.genPupil(param.dim_yx, param.pixel_size, NA=param.na, wavelength=param.wavelength)
Hu, Hp = dpc_utils.gen_transfer_func(list_source=s, pupil=pupil, wavelength=param.wavelength, shifted_out=False)

# Row 0: real part of Hu; row 1: imaginary part of Hp. fftshift centers the
# zero frequency because the transfer functions were generated unshifted.
# squeeze=False keeps `axes` 2-D even when there is only one source pattern.
num_sources = s.shape[0]
f, axes = plt.subplots(2, num_sources, figsize=(12, 6), squeeze=False)
for i in range(num_sources):
    axes[0, i].imshow(np.fft.fftshift(Hu[i].real), extent=(0,1,0,1))
    axes[1, i].imshow(np.fft.fftshift(Hp[i].imag), extent=(0,1,0,1))

## Baseline reconstruction

In [None]:
# Conventional DPC reconstruction (Tikhonov-regularized solver), used as the
# baseline. The two 1e-4 values are presumably the regularization weights for
# the Hu and Hp terms respectively (TODO confirm against dpc_utils).
tikhonov_reg_u, tikhonov_reg_p = 1e-4, 1e-4
abs_baseline, phase_baseline = dpc_utils.dpc_tikhonov_solver(
    img,
    Hu,
    Hp,
    tikhonov_reg_u,
    tikhonov_reg_p,
    param.wavelength,
)

## Reconstruction via neural space-time model

### Define NSTM parameters

In [None]:
# model parameters
object_fine_hash_ratio = 0.4
object_base_hash_ratio = 0.05
motion_hash_ratio = 0.03
motion_hash_temporal = 2

hash_param = HashParameters(bounding_box=(np.array([-param.dim_yx[0] * 0.5, -param.dim_yx[1] * 0.5]),
                                          np.array([param.dim_yx[0] * 1.5, param.dim_yx[1] * 1.5])),
                            n_levels=8, n_features_per_level=2, log2_hashmap_size=16,
                            base_resolution=np.array([dim_yx[0]*object_base_hash_ratio*2, dim_yx[1]*object_base_hash_ratio*2]),
                            finest_resolution=np.array([dim_yx[0]*object_fine_hash_ratio*2, dim_yx[1]*object_fine_hash_ratio*2]))
hash_param_motion_spacetime = HashParameters(
    bounding_box=(np.array([0, 0, -1]), np.array([param.dim_yx[0], param.dim_yx[1], 1])),
    n_levels=8, n_features_per_level=2, log2_hashmap_size=16, base_resolution=np.array([1,1,1]),
    finest_resolution=np.round(np.array([dim_yx[0]*object_fine_hash_ratio*motion_hash_ratio, dim_yx[1]*object_fine_hash_ratio*motion_hash_ratio, motion_hash_temporal])))

object_mlp_param = spacetime.MLPParameters(net_depth=2, net_width=128,
                                           net_activation=nn.gelu, skip_layer=4)
motion_mlp_param = spacetime.MLPParameters(net_depth=2, net_width=32,
                                           net_activation=nn.elu, skip_layer=6)

spacetime_param = spacetime.SpaceTimeParameters(motion_mlp_param=motion_mlp_param,
                                                 object_mlp_param=object_mlp_param,
                                                 motion_embedding='hash_combined',
                                                 motion_embedding_param=hash_param_motion_spacetime,
                                                 object_embedding='hash', object_embedding_param=hash_param,
                                                 out_activation=lambda x: x)

### Data loader

In [None]:
# One training sample per captured frame: the image, its normalized time in
# [-1, 1], and a one-hot row (presumably selecting the illumination pattern).
# The frame count comes from the data instead of a hard-coded 4.
num_patterns = img.shape[0]
batch_size = 4
pat_indices = np.eye(num_patterns)
time_norm = np.linspace(-1, 1, num_patterns)
data_loader = cc.data_utils.loader_from_numpy({'img': img,
                                               't': time_norm,
                                               'ind_pat': pat_indices}, prefix_dim=(batch_size,),
                                              seed=85472, )
# next(data_loader) appears to yield the list of minibatches for one epoch.
# Draw it once and use it for both the sample input and the step count,
# instead of advancing the loader twice.
first_epoch = next(data_loader)
sample_input_dict = first_epoch[0]
num_steps_per_epoch = len(first_epoch)

### Initialize model

In [None]:
# Fixed seed so parameter initialization is reproducible.
rng = jax.random.PRNGKey(0)

# DPC forward model wrapping the neural space-time representation.
# annealed_epoch is passed straight to DPCFlow (semantics defined there -- TODO document).
model = dpc_flow.DPCFlow(
    param,
    s,
    spacetime_param,
    annealed_epoch=4000,
)

# Create all trainable variables by running the model once on a sample batch.
variables = model.init(rng, input_dict=sample_input_dict)

### Define loss function

In [None]:
# Objective: L2 data fidelity (margin=5, presumably excluding a border) plus
# weighted L2 regularizers on absorption and phase.
reg_weight = 1e-4
l2 = cc.loss.Loss(dpc_flow.gen_loss_l2(margin=5), 'l2')
reg_absorp = cc.loss.Loss(
    dpc_flow.gen_l2_reg_absorp(freq_space=False),
    'reg_l2_absorp',
    has_intermediates=True,
)
reg_phase = cc.loss.Loss(
    dpc_flow.gen_l2_reg_phase(freq_space=False),
    'reg_l2_phase',
    has_intermediates=True,
)
total_loss = l2 + reg_absorp * reg_weight + reg_phase * reg_weight

### Run reconstruction

In [None]:
# Directory for checkpoints and outputs of this run.
save_path = './checkpoint/DPC_c_elegans/'

# 5000 epochs, keep a single checkpoint (written at epoch 5000),
# output/log interval of 100.
recon_param = cc.reconstruction.ReconIterParameters(
    save_dir=save_path,
    n_epoch=5000,
    keep_checkpoints=1,
    checkpoint_every=5000,
    output_every=100,
    log_every=100,
)


def _adam_with_exp_decay(lr):
    """Return Adam settings with an exponential LR schedule (decay 0.1 per 5e3 transition steps).

    Each call builds fresh kwargs dicts, so different parameter groups never share them.
    """
    return cc.reconstruction.ReconVarParameters(
        lr=lr,
        opt='adam',
        opt_kwargs={'b1': 0.9, 'b2': 0.99, 'eps': 1e-15},
        schedule='exponential',
        schedule_kwargs={'transition_steps': 5e3, 'decay_rate': 0.1, 'transition_begin': 0},
        update_every=1,
    )


# Object network/embedding train with a 100x larger learning rate than the motion ones.
object_mlp_params = _adam_with_exp_decay(1e-3)
motion_mlp_params = _adam_with_exp_decay(1e-5)

# Optimizer settings per sub-tree of the model parameters.
var_params = {
    'params': {
        'spacetime': {
            'motion_mlp': motion_mlp_params,
            'object_mlp': object_mlp_params,
            'motion_embedding': motion_mlp_params,
            'object_embedding': object_mlp_params,
        },
    },
}

recon_variables, recon = cc.reconstruction.reconstruct_multivars_sgd(
    model.apply,
    variables,
    var_params,
    data_loader,
    total_loss,
    recon_param,
    None,
    None,
)

def _spacetime_at(module, a, b):
    """Call only the `spacetime` submodule (bypassing the DPC forward model)."""
    return module.spacetime(a, b)


# Evaluate the trained space-time model at each measurement's normalized time.
# The second argument [[0, 0]] is passed unchanged (meaning defined by spacetime -- TODO document).
# [0] keeps the first output; the stacked result is indexed [time, y, x, channel]
# (channel 0 = amplitude, 1 = phase in the plots below).
recon_t = np.array([
    model.apply(recon_variables, np.array([t]), np.array([[0, 0]]), method=_spacetime_at)[0]
    for t in time_norm.reshape((-1))
])


### Visualize the reconstruction

In [None]:
# Side-by-side viewer: the neural space-time model at each measured time point
# (left column) versus the conventional reconstruction (right column);
# top row amplitude, bottom row phase.
f, axes = plt.subplots(2, 2, figsize=(10, 4), sharex=True, sharey=True)
frames = []  # image artists, in the order created by init()
margin = 20  # pixels cropped from every border before display

# Display ranges come from the NSTM result and are reused for the baseline,
# so both methods are shown on the same gray scale.
clim_ = (np.percentile(recon_t[..., 0], 0.1), np.percentile(recon_t[..., 0], 99))
clim_phase = (np.percentile(recon_t[..., 1], 70), np.percentile(recon_t[..., 1], 99.99))


def init():
    """Create the four panels, their titles and row labels, then hide the axes."""
    frames.append(axes[0, 0].imshow(recon_t[0,margin:-margin,margin:-margin,0], cmap='gray', clim=clim_, interpolation='None'))
    frames.append(axes[1, 0].imshow(recon_t[0,margin:-margin,margin:-margin,1], cmap='gray', clim=clim_phase, interpolation='None'))
    frames.append(axes[0, 1].imshow(abs_baseline[margin:-margin, margin:-margin], cmap='gray',clim=clim_, interpolation='None'))
    frames.append(axes[1, 1].imshow(phase_baseline[margin:-margin, margin:-margin], cmap='gray', clim=clim_phase, interpolation='None'))

    axes[0,0].set_title('neural space-time model', fontsize=14)
    axes[0,1].set_title('conventional reconstruction', fontsize=14)
    # Row labels drawn just left of the first column, in axes coordinates.
    axes[0,0].text(-0., 0.5, 'amplitude',
        horizontalalignment='right',
        verticalalignment='center',
        rotation='vertical',
        transform=axes[0,0].transAxes, fontsize=14)
    axes[1,0].text(-0., 0.5, 'phase',
        horizontalalignment='right',
        verticalalignment='center',
        rotation='vertical',
        transform=axes[1,0].transAxes, fontsize=14)

    # With sharex=True this x-range applies to every panel.
    axes[0,0].set_xlim([65, 765])
    for ax in axes.flat:
        ax.axis('off')
    f.tight_layout()


init()


def updateFrame(t):
    """Show the NSTM reconstruction at measured time index ``t``; the baseline stays fixed."""
    frames[0].set_data(recon_t[t,margin:-margin,margin:-margin,0])
    frames[1].set_data(recon_t[t,margin:-margin,margin:-margin,1])


# One slider step per measured time point (instead of a hard-coded max of 3).
interact(updateFrame, t=IntSlider(min=0, max=recon_t.shape[0] - 1, step=1, value=0))

## Temporal interpolation

In [None]:
def _call_spacetime(module, a, b):
    """Call only the `spacetime` submodule of the trained model."""
    return module.spacetime(a, b)


# Query the model at 100 evenly spaced normalized times in [-1, 1],
# i.e. also in between the measured frames.
dense_times = np.linspace(-1, 1, 100)
recon_t_dense = np.array([
    model.apply(recon_variables, np.array([t_query]), np.array([[0, 0]]), method=_call_spacetime)[0]
    for t_query in dense_times
])

In [None]:
# Viewer for the temporally interpolated reconstruction: the neural space-time model
# at densely sampled times (left column) vs. the static conventional reconstruction
# (right column); top row amplitude, bottom row phase.
f, axes = plt.subplots(2, 2, figsize=(10, 4), sharex=True, sharey=True)
frames, text = [], []
margin = 20  # pixels cropped from every border before display

# Fractional, 1-based frame number of each dense sample. The dense samples span
# the same normalized time range as the measured frames (first to last frame).
frame_number_interp = np.linspace(1, recon_t.shape[0], recon_t_dense.shape[0])

# Same gray-scale ranges as the per-frame viewer (computed from the measured-time result).
clim_ = (np.percentile(recon_t[..., 0], 0.1), np.percentile(recon_t[..., 0], 99))
clim_phase = (np.percentile(recon_t[..., 1], 70), np.percentile(recon_t[..., 1], 99.99))


def init():
    """Create the panels, labels, frame counter and color bars for the first dense sample."""
    frames.append(axes[0, 0].imshow(recon_t_dense[0,margin:-margin,margin:-margin,0], cmap='gray', clim=clim_, interpolation='None'))
    frames.append(axes[1, 0].imshow(recon_t_dense[0,margin:-margin,margin:-margin,1], cmap='gray', clim=clim_phase, interpolation='None'))
    frames.append(axes[0, 1].imshow(abs_baseline[margin:-margin, margin:-margin], cmap='gray',clim=clim_, interpolation='None'))
    frames.append(axes[1, 1].imshow(phase_baseline[margin:-margin, margin:-margin], cmap='gray', clim=clim_phase, interpolation='None'))

    axes[0,0].set_title('neural space-time model', fontsize=14)
    axes[0,1].set_title('conventional reconstruction', fontsize=14)
    # Row labels drawn just left of the first column, in axes coordinates.
    axes[0,0].text(-0., 0.5, 'amplitude',
        horizontalalignment='right',
        verticalalignment='center',
        rotation='vertical',
        transform=axes[0,0].transAxes, fontsize=14)
    axes[1,0].text(-0., 0.5, 'phase',
        horizontalalignment='right',
        verticalalignment='center',
        rotation='vertical',
        transform=axes[1,0].transAxes, fontsize=14)

    f.tight_layout()

    # With sharex=True this x-range applies to every panel.
    axes[0,0].set_xlim([65, 765])

    # Frame counter, in the same format updateFrameVideo uses.
    text.append(axes[0,0].text(105.5, 55.6, 'frame {:0.2f}'.format(frame_number_interp[0]), color='black', fontsize=14, ))

    # color bar
    f.colorbar(frames[0], ax=axes[0], location='right', anchor=(-0.3, 0.5), shrink=0.5)
    f.colorbar(frames[1], ax=axes[1], location='right', anchor=(-0.3, 0.5), shrink=0.5, ticks=[0, 1, 2, 3, 4])

    for ax in axes.flat:
        ax.axis('off')


init()


def updateFrameVideo(i):
    """Show dense sample ``i`` and its fractional frame number; the baseline panels are static."""
    text[0].set_text('frame {:0.2f}'.format(frame_number_interp[i]))
    frames[0].set_data(recon_t_dense[i,margin:-margin,margin:-margin,0])
    frames[1].set_data(recon_t_dense[i,margin:-margin,margin:-margin,1])


# One slider step per dense sample (instead of a hard-coded max of 99).
interact(updateFrameVideo, i=IntSlider(min=0, max=len(frame_number_interp) - 1, step=1, value=0))

## Reference