# JAX FFN inference on LICONN data


In [None]:
# Install the latest snapshot from the FFN repository.
!pip install git+https://github.com/google/ffn

In [2]:
import os

# Both variables are set before protobuf / TensorFlow get imported:
# - force the pure-Python protobuf implementation;
# - point the GCE metadata server at an invalid host so that tensorstore does
#   not attempt to use GCE credentials.
os.environ.update({
    'PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION': 'python',
    'GCE_METADATA_ROOT': 'metadata.google.internal.invalid',
})

import tensorflow as tf

# Hide all GPUs from TensorFlow (presumably so that it does not claim
# accelerator memory needed by JAX -- TODO confirm).
tf.config.set_visible_devices([], 'GPU')

In [3]:
from clu import checkpoint
from connectomics.jax.models import convstack
import jax
import matplotlib.pyplot as plt
import numpy as np
import tensorstore as ts

from ffn.inference import inference
from ffn.inference import inference_utils
from ffn.inference import inference_pb2
from ffn.inference import executor
from ffn.training import model as ffn_model

In [16]:
# Check for GPU/TPU presence. If this fails, use "Runtime > Change runtime type".
jax_platform = jax.devices()[0].platform
assert jax_platform in ('gpu', 'tpu'), (
    f'Expected a GPU or TPU runtime, but JAX is using {jax_platform!r}. '
    'Use "Runtime > Change runtime type" to select an accelerator.')

In [17]:
# Open the sample LICONN image volume (2x downsampled, i.e. scale_index=1) and
# select its first channel. The shared tensorstore cache is capped at ~1 GB.
context = ts.Context({'cache_pool': {'total_bytes_limit': 1_000_000_000}})
img_spec = {
    'driver': 'neuroglancer_precomputed',
    'kvstore': {'driver': 'gcs', 'bucket': 'liconn-public'},
    'path': 'ExPID82_1/image_230130b',
    'scale_index': 1,
}
img_all_channels = ts.open(img_spec, read=True, context=context).result()
img = img_all_channels[ts.d['channel'][0]]

In [34]:
# Load a 500^3 subvolume for local processing.
x0, y0, z0 = 1100, 1083, 111
subvol_size = 500
# Lazy tensorstore view of the region (x, y, z order); nothing is read yet.
subvol_xyz = img[x0:x0 + subvol_size, y0:y0 + subvol_size, z0:z0 + subvol_size]
# Read into memory and reorder the axes from xyz to zyx.
raw = subvol_xyz.read().result().transpose(2, 1, 0)
# Normalize intensities for inference.
raw = (raw.astype(np.float32) - 128.0) / 33.

In [None]:
# Show the central z-section of the normalized subvolume as a sanity check.
# The index is derived from the array shape rather than hard-coded, so it stays
# valid if the subvolume size changes.
z_mid = raw.shape[0] // 2
plt.matshow(raw[z_mid, :, :], cmap=plt.cm.Greys_r)
plt.title(f'Normalized raw data, z = {z_mid} (subvolume coordinates)')
plt.xlabel('x (voxels)')
plt.ylabel('y (voxels)');  # trailing ';' suppresses the Text repr output

In [None]:
# Load sample model checkpoint.
!gsutil cp gs://liconn-public/models/ffn/axde_59110972/ckpt-2116* .

ckpt = checkpoint.Checkpoint('')
state = ckpt.load_state(state=None, checkpoint='ckpt-2116')

In [21]:
# Instantiate model for inference: a depth-20 ResConvStack with 'same' padding
# and layer norm, using identical 33^3 sizes for the image input, the seed
# input and the predicted mask.
fov_size = (33, 33, 33)
model = convstack.ResConvStack(
    convstack.ConvstackConfig(depth=20, padding='same', use_layernorm=True))
model_info = ffn_model.ModelInfo(
    deltas=(8, 8, 8),  # presumably the FOV movement step per axis -- TODO confirm
    pred_mask_size=fov_size,
    input_seed_size=fov_size,
    input_image_size=fov_size,
)

@jax.jit
def _apply_model(params, data):
  """Runs the FFN network on one batch of inputs.

  `params` is a traced argument rather than a value closed over from `state`,
  so the weights are not embedded into the compiled XLA program as constants.
  """
  return model.apply({'params': params}, data)


def _apply_fn(data):
  """Applies the restored model to `data`; used as the executor's callback.

  Args:
    data: model input batch as assembled by `executor.JAXExecutor` (its exact
      layout is defined by the executor, not here).

  Returns:
    The output of `model.apply` with the checkpoint parameters.
  """
  return _apply_model(state['params'], data)

# Start the inference executor; the Canvas created below obtains a client to
# it via `exc.get_client`. The final positional argument (1) is presumably the
# inference batch size -- TODO confirm against executor.JAXExecutor.
counters = inference_utils.Counters()
iface = executor.ExecutorInterface()
exc = executor.JAXExecutor(iface, model_info, _apply_fn, counters, 1)
exc.start_server()

In [39]:
# Inference settings, and a Canvas over the normalized zyx subvolume.
# Voxel sizes are in nm, given in zyx order.
options = inference_pb2.InferenceOptions(
    init_activation=0.95,
    pad_value=0.5,
    move_threshold=0.6,
    segment_threshold=0.6,
)
client = exc.get_client(counters)
cv = inference.Canvas(model_info, client, raw, options, voxel_size_zyx=(24, 18, 18))

In [None]:
# Trace a single neurite from a seed point given in xyz subvolume coordinates.
pos_xyz = (123, 171, 225)
pos_zyx = tuple(reversed(pos_xyz))  # the Canvas works in zyx order
cv.segment_at(pos_zyx, dynamic_image=inference.DynamicImage(), vis_update_every=10)

In [None]:
!pip install neuroglancer

In [48]:
# Visualize results in neuroglancer.
import neuroglancer
from scipy.special import expit
from scipy.ndimage import label

# Binarize the FFN object mask (`cv.seed > 0`; seed presumably holds logits, so
# this is p > 0.5 -- TODO confirm). Then keep only the connected component that
# contains the seed point. `label` numbers components in raster-scan order, so
# component 1 is merely the first one encountered. It can be a stray fragment
# rather than the traced neurite.
seed_zyx = tuple(pos_xyz[::-1])
mask_components, _ = label(cv.seed > 0)
seed_component = mask_components[seed_zyx]
assert seed_component > 0, (
    f'Seed voxel {pos_xyz} (xyz) is not part of the thresholded mask.')
seg = (mask_components == seed_component).astype(np.uint64)

In [None]:
dimensions = neuroglancer.CoordinateSpace(
    names=['x', 'y', 'z'],
    units='nm',
    scales=[18, 18, 24],  # xyz order; same voxel size as voxel_size_zyx=(24, 18, 18)
)

# Undo the inference normalization ((x - 128) / 33) to recover uint8 intensities.
# Round before casting. astype() truncates, and float32 round-off in the round
# trip can yield values such as 4.9999995, which would otherwise become 4.
raw_uint8 = np.clip(np.rint(raw * 33 + 128), 0, 255).astype(np.uint8)

viewer = neuroglancer.Viewer()
with viewer.txn() as s:
  s.dimensions = dimensions
  # Both arrays are stored zyx; transpose to match the xyz coordinate space.
  s.layers['raw'] = neuroglancer.ImageLayer(
      source=neuroglancer.LocalVolume(np.transpose(raw_uint8, [2, 1, 0]), dimensions))
  s.layers['seg'] = neuroglancer.SegmentationLayer(
      source=neuroglancer.LocalVolume(
          np.transpose(seg.astype(np.uint64), [2, 1, 0]), dimensions),
      segments=[1])

viewer