In [None]:
from zephir.methods import *
from zephir.methods.plot_loss_maps import plot_loss_maps
from zephir.methods.get_optimization_trajectory import get_optimization_trajectory
from zephir.models.container import Container
from zephir.utils.utils import *

%matplotlib inline

In [None]:
# path to dataset ('.' is a relative path)
# NOTE(review): Path is not imported by name anywhere in this notebook;
# presumably it is pulled in by one of the star imports in the first cell
# (and is pathlib.Path) -- confirm
dataset = Path('.')

In [None]:
# collect the key arguments first, then build the variable container from them
container_args = dict(
    dataset=dataset,
    allow_rotation=False,
    channel=1,
    dev='cpu',
    exclude_self=True,
    exclusive_prov=None,
    gamma=2,
    include_all=True,
    n_frame=1,
    z_compensator=4.0,
)
container = Container(**container_args)

In [None]:
# load and handle annotations; annotation, t_ref, wlid_ref and n_ref are all
# passed as None here. The call returns the container together with `results`.
container, results = build_annotations(
    container=container, annotation=None, t_ref=None, wlid_ref=None, n_ref=None,
)

In [None]:
# compile the models; build_models returns the container plus zephir and zephod
# NOTE: the same dimmer_ratio / grid_shape / fovea_sigma / n_chunks values are
# passed again to plot_loss_maps in the last cell -- change both together
model_options = dict(
    dimmer_ratio=0.1,
    grid_shape=(5, 25, 25),
    fovea_sigma=(1, 2.5, 2.5),
    n_chunks=10,
)
container, zephir, zephod = build_models(container=container, **model_options)

In [None]:
# compile the spring network (load_nn=False, nn_max=5, verbose output on)
container = build_springs(container=container, load_nn=False, nn_max=5, verbose=True)

In [None]:
# three frames are defined here:
#   frame_to_visualize - the frame whose losses are visualized
#   parent             - initializes the keypoint optimization trajectory
#   reference          - registration target descriptors are sampled from it
# parent and reference frames should be fully annotated
frame_to_visualize, parent, reference = 600, 640, 498

In [None]:
# optimize keypoint coordinates for t=frame_to_visualize (handed to the
# function as its frame_to_optimize argument), initialized at parent,
# and keep the trajectory over optimization epochs
trajectory_frames = dict(
    frame_to_optimize=frame_to_visualize,
    parent=parent,
    reference=reference,
)
trajectory_state = dict(
    container=container,
    results=results,
    zephir=zephir,
    zephod=zephod,
)
# NOTE: lambda_d, lambda_n and lambda_n_mode are repeated with the same values
# in the plot_loss_maps call of the last cell
trajectory_hparams = dict(
    clip_grad=1.0,
    lambda_t=-1,
    lambda_d=1.0,
    lambda_n=1.0,
    lambda_n_mode='norm',
    lr_ceiling=0.2,
    lr_coef=2.0,
    lr_floor=0.01,
    n_epoch=40,
    n_epoch_d=10,
)
trajectory = get_optimization_trajectory(
    **trajectory_frames, **trajectory_state, **trajectory_hparams,
)

In [None]:
# plot loss maps for a single keypoint (keypoint_to_visualize=40)
#   top row:    centered at the optimized keypoint coordinate
#   bottom row: centered at the manual annotation for the keypoint
#   columns, left to right:
#     crop of input volume
#     registration loss (L_R)
#     spatial regularization (L_N)
#     feature detection loss (L_D)
#     sum of the three losses (L_R+L_N+L_D)

# same four values as given to build_models earlier in the notebook
loss_map_model_args = dict(
    dimmer_ratio=0.1,
    grid_shape=(5, 25, 25),
    fovea_sigma=(1, 2.5, 2.5),
    n_chunks=10,
)
# same three values as given to get_optimization_trajectory earlier
loss_map_weights = dict(lambda_d=1.0, lambda_n=1.0, lambda_n_mode='norm')

losses_at_trajectory, losses_at_annotation = plot_loss_maps(
    keypoint_to_visualize=40,
    frame_to_visualize=frame_to_visualize,
    reference=reference,
    map_resolution=50,
    map_size=20,
    trajectory=trajectory,
    container=container,
    **loss_map_model_args,
    zephod=zephod,
    **loss_map_weights,
)