# Phase Unwrapping

> Phase unwrapping of point-cloud interferograms with minimum cost flow (MCF)

In [None]:
#| default_exp cli/pu

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
import moraine.cli as mc
import zarr
import numpy as np
import toml
import holoviews as hv
from bokeh.models import WheelZoomTool

In [None]:
# Use the bokeh backend for holoviews plots and put interactive widgets below the plots.
hv.extension('bokeh')
hv.output(widget_location='bottom')

In [None]:
#| export
import logging
import zarr
import time
import numpy as np

import dask
from dask import array as da
from dask import delayed
from dask.distributed import Client, LocalCluster, progress
import moraine as mr
from moraine.cli.logging import mc_logger
from moraine.cli import dask_from_zarr, dask_to_zarr, parallel_read_zarr

In [None]:
#| export
@mc_logger
def gamma_mcf_pt(
    pc_x:str, # x coordinate, shape of (N,)
    pc_y:str, # y coordinate, shape of (N,)
    ph:str, # stack of wrapped phase, shape of (N,M)
    unw_ph:str, # output, unwrapped phase, shape of (N,L)
    image_pairs:np.ndarray,# image pairs (L,2) to construct interferograms for unwrapping, array-like accepted
    ref_point:int=1, # reference point; NOTE: currently unused, it is not forwarded to `mr.gamma_mcf_pt`
    out_chunks:int=None, # unw_ph point cloud chunk size, same as ph by default
    n_workers=1, # number of dask worker, number of interferograms to be unwrapped in the same time
    threads_per_worker=2, # number of threads per dask worker
    **dask_cluster_arg, # other dask local/cudalocal cluster args
):
    '''A wrapper for mcf_pt in GAMMA software (via `mr.gamma_mcf_pt`).

    For every row `(ref, sec)` of `image_pairs`, the interferogram between
    columns `ref` and `sec` of `ph` is formed with `mr.intf`, unwrapped with
    `mr.gamma_mcf_pt` using the point coordinates, and written as column `i`
    of `unw_ph`, so `unw_ph` has shape `(N, len(image_pairs))`.
    Interferograms are processed as independent tasks on a dask local cluster.

    Raises `ValueError` if `image_pairs` is not a non-empty `(L, 2)` array.
    '''
    logger = logging.getLogger(__name__)

    # accept lists/tuples as well as ndarrays; every row must be one (ref, sec) pair
    image_pairs = np.asarray(image_pairs)
    if image_pairs.ndim != 2 or image_pairs.shape[0] == 0 or image_pairs.shape[1] != 2:
        raise ValueError(f'image_pairs must be a non-empty array of shape (L, 2), got shape {image_pairs.shape}.')

    logger.info('load coordinates')
    pc_x_data = parallel_read_zarr(zarr.open(pc_x,mode='r'),(slice(None),))
    pc_y_data = parallel_read_zarr(zarr.open(pc_y,mode='r'),(slice(None),))
    logger.info('Done')

    ph_path = ph
    unw_ph_path = unw_ph

    ph_zarr = zarr.open(ph_path,mode='r')
    logger.zarr_info(ph_path,ph_zarr)
    # unpacking also enforces that the phase stack is 2-D (N, M)
    npoint, nimage = ph_zarr.shape
    nimage_pairs = image_pairs.shape[0]

    if out_chunks is None: out_chunks = ph_zarr.chunks[0]

    cluster_args = {'processes':True, 'n_workers':n_workers, 'threads_per_worker':threads_per_worker}
    cluster_args.update(dask_cluster_arg)

    logger.info('starting dask local cluster.')
    with LocalCluster(**cluster_args) as cluster, Client(cluster) as client:
        logger.info('dask local cluster started.')
        logger.dask_cluster_info(cluster)

        # one chunk per image (column): each interferogram task only depends on its two columns
        ph = dask_from_zarr(ph_path,chunks=(npoint,1))
        logger.darr_info('ph', ph)

        # coordinates are shared by every task, so keep each of them as a single chunk
        pc_x_delayed = da.from_array(pc_x_data,chunks=pc_x_data.shape).to_delayed()[0]
        pc_y_delayed = da.from_array(pc_y_data,chunks=pc_y_data.shape).to_delayed()[0]
        # the chunk grid of ph is (1, M): row 0 holds one delayed (N, 1) block per image
        ph_delayed = ph.to_delayed()[0]

        logger.info('phase unwrapping with mcf.')

        f_mcf_delayed = delayed(mr.gamma_mcf_pt,pure=True,nout=1)
        f_intf_delayed = delayed(mr.intf,pure=True,nout=1)
        unw_ph_blocks = np.empty((1,nimage_pairs),dtype=object)
        for i, (ref, sec) in enumerate(image_pairs):
            intf_delayed = f_intf_delayed(ph_delayed[ref],ph_delayed[sec])
            unw_delayed = f_mcf_delayed(pc_x_delayed, pc_y_delayed, intf_delayed)
            unw_ph_blocks[0,i] = da.from_delayed(unw_delayed,shape=(npoint,1),meta=np.array((),dtype=np.float32))
        # concatenate the (N, 1) columns side by side into an (N, L) array
        unw_ph = da.block(unw_ph_blocks.tolist())

        logger.info('got unwrapped phase.')
        logger.darr_info('unw_ph', unw_ph)
        logger.info('save unw_ph')
        _unw_ph = dask_to_zarr(unw_ph, unw_ph_path,chunks=(out_chunks,1))

        logger.info('computing graph setted. doing all the computing.')
        futures = client.persist(_unw_ph)
        progress(futures,notebook=False)
        # NOTE(review): presumably lets the progress bar render before blocking — confirm
        time.sleep(0.1)
        da.compute(futures)
        logger.info('computing finished.')
    logger.info('dask cluster closed.')

In [None]:
#| export
@mc_logger
def mcf_pc(
    gix:str, # grid index, shape of (N, 2), int
    ph:str, # stack of wrapped phase, shape of (N,M)
    unw_ph:str, # output, unwrapped phase, shape of (N,L)
    image_pairs:np.ndarray,# image pairs (L,2) to construct interferograms for unwrapping, array-like accepted
    out_chunks:int=None, # unw_ph point cloud chunk size, same as ph by default
    n_workers=1, # number of dask worker, number of interferograms to be unwrapped in the same time
    threads_per_worker=2, # number of threads per dask worker
    **dask_cluster_arg, # other dask local/cudalocal cluster args
):
    '''Unwrap point cloud interferograms with the minimum cost flow (MCF) solver in `moraine.pu`.

    The Delaunay triangulation and MCF solver are prepared once from the
    point grid index with `mr.pu._prepare_mcf`, then broadcast to every dask
    worker. For every row `(ref, sec)` of `image_pairs`, the interferogram
    between columns `ref` and `sec` of `ph` is formed with `mr.intf`,
    unwrapped with `mr.pu._solve_mcf`, and written as column `i` of `unw_ph`,
    so `unw_ph` has shape `(N, len(image_pairs))`.

    Raises `ValueError` if `image_pairs` is not a non-empty `(L, 2)` array.
    '''
    logger = logging.getLogger(__name__)

    # accept lists/tuples as well as ndarrays; every row must be one (ref, sec) pair
    image_pairs = np.asarray(image_pairs)
    if image_pairs.ndim != 2 or image_pairs.shape[0] == 0 or image_pairs.shape[1] != 2:
        raise ValueError(f'image_pairs must be a non-empty array of shape (L, 2), got shape {image_pairs.shape}.')

    logger.info('load coordinates')
    gix_data = parallel_read_zarr(zarr.open(gix, mode='r'),(slice(None),slice(None)))
    # column 1 of the grid index is used as x and column 0 as y
    pc_x_data, pc_y_data = gix_data[:,1], gix_data[:,0]
    logger.info('Done')

    ph_path = ph
    unw_ph_path = unw_ph

    ph_zarr = zarr.open(ph_path,mode='r')
    logger.zarr_info(ph_path,ph_zarr)
    # unpacking also enforces that the phase stack is 2-D (N, M)
    npoint, nimage = ph_zarr.shape
    nimage_pairs = image_pairs.shape[0]

    if out_chunks is None: out_chunks = ph_zarr.chunks[0]

    # the triangulation/solver setup only depends on the coordinates: build it once for all pairs
    logger.info('construct Delaunay triangulation and mcf solver')
    required_data = mr.pu._prepare_mcf(pc_x_data, pc_y_data)
    logger.info('Done')

    cluster_args = {'processes':True, 'n_workers':n_workers, 'threads_per_worker':threads_per_worker}
    cluster_args.update(dask_cluster_arg)

    logger.info('starting dask local cluster.')
    with LocalCluster(**cluster_args) as cluster, Client(cluster) as client:
        logger.info('dask local cluster started.')
        logger.dask_cluster_info(cluster)

        # ship the solver data to every worker once instead of embedding it in every task
        required_data_future = [client.scatter(data, broadcast=True) for data in required_data]
        # one chunk per image (column): each interferogram task only depends on its two columns
        ph = dask_from_zarr(ph_path,chunks=(npoint,1))
        logger.darr_info('ph', ph)
        logger.info('phase unwrapping with mcf.')

        f_mcf_delayed = delayed(mr.pu._solve_mcf,pure=True,nout=1)
        f_intf_delayed = delayed(mr.intf,pure=True,nout=1)
        unw_ph_blocks = np.empty((1,nimage_pairs),dtype=object)
        for i, (ref, sec) in enumerate(image_pairs):
            # a column of ph is a single-chunk 1-D array, so to_delayed() yields exactly one block
            ref_ph_delayed = ph[:,ref].to_delayed()[0]
            sec_ph_delayed = ph[:,sec].to_delayed()[0]
            intf_delayed = f_intf_delayed(ref_ph_delayed,sec_ph_delayed)
            unw_delayed = f_mcf_delayed(intf_delayed, *required_data_future)
            # the solver output is declared 1-D (N,); reshape it to an (N, 1) column for da.block
            unw_ph_blocks[0,i] = da.from_delayed(unw_delayed,shape=(npoint,),meta=np.array((),dtype=np.float32)).reshape(npoint,1)
        # concatenate the (N, 1) columns side by side into an (N, L) array
        unw_ph = da.block(unw_ph_blocks.tolist())

        logger.info('got unwrapped phase.')
        logger.darr_info('unw_ph', unw_ph)
        logger.info('save unw_ph')
        _unw_ph = dask_to_zarr(unw_ph, unw_ph_path,chunks=(out_chunks,1))

        logger.info('computing graph setted. doing all the computing.')
        futures = client.persist(_unw_ph)
        progress(futures,notebook=False)
        # NOTE(review): presumably lets the progress bar render before blocking — confirm
        time.sleep(0.1)
        da.compute(futures)
        logger.info('computing finished.')
    logger.info('dask cluster closed.')

Usage:

In [None]:
# Fetch the moraine CLI logger (presumably sets up the log output shown by the calls below — confirm).
logger = mc.get_logger()

In [None]:
# Paths used by the example: phase-linked wrapped phase (ds_ph), the grid index of
# every point (ds_gix) and the output unwrapped phase (ds_unw).
ds_ph = './pu/ds_ph.zarr/'
ds_gix = './pu/ds_gix.zarr/'
ds_unw = './pu/ds_unw.zarr/'
ds_ph_zarr = zarr.open(ds_ph,mode='r')
# bandwidth=1 pairs every image with the next one: (0, 1), (1, 2), ...
tnet = mr.TempNet.from_bandwidth(ds_ph_zarr.shape[1],bandwidth=1)

In [None]:
# Unwrap the sequential interferograms with the built-in MCF solver; the result is written to ds_unw.
mcf_pc(ds_gix, ds_ph, ds_unw, tnet.image_pairs)
## or if you have gamma access
# ds_e = './pu/ds_e.zarr/'
# ds_n = './pu/ds_n.zarr/'
# gamma_mcf_pt(ds_e, ds_n, ds_ph, ds_unw, tnet.image_pairs)

2025-11-02 16:57:17 - log_args - INFO - running function: mcf_pc
2025-11-02 16:57:17 - log_args - INFO - fetching args:
2025-11-02 16:57:17 - log_args - INFO - gix = './pu/ds_gix.zarr/'
2025-11-02 16:57:17 - log_args - INFO - ph = './pu/ds_ph.zarr/'
2025-11-02 16:57:17 - log_args - INFO - unw_ph = './pu/ds_unw.zarr/'
2025-11-02 16:57:17 - log_args - INFO - image_pairs = array([[ 0,  1],
       [ 1,  2],
       [ 2,  3],
       [ 3,  4],
       [ 4,  5],
       [ 5,  6],
       [ 6,  7],
       [ 7,  8],
       [ 8,  9],
       [ 9, 10],
       [10, 11],
       [11, 12],
       [12, 13],
       [13, 14],
       [14, 15],
       [15, 16]], dtype=int32)
2025-11-02 16:57:17 - log_args - INFO - out_chunks = None
2025-11-02 16:57:17 - log_args - INFO - n_workers = 1
2025-11-02 16:57:17 - log_args - INFO - threads_per_worker = 2
2025-11-02 16:57:17 - log_args - INFO - dask_cluster_arg = {}
2025-11-02 16:57:17 - log_args - INFO - fetching args done.
2025-11-02 16:57:17 - mcf_pc - INFO - load c

In [None]:
# note that the data is already in hilbert order
# Build display pyramids for the wrapped and the unwrapped phase on the ds_e/ds_n coordinates.
pyramid_jobs = (
    ('./pu/ds_ph.zarr', './pu/ds_ph_geo_pyramid'),
    ('./pu/ds_unw.zarr', './pu/ds_unw_geo_pyramid'),
)
for pc_in, pyramid_out in pyramid_jobs:
    mc.pc_pyramid(
        pc_in,
        pyramid_out,
        x='./pu/ds_e.zarr/',
        y='./pu/ds_n.zarr/',
        ras_resolution=20,
    )

2025-11-02 16:58:08 - log_args - INFO - running function: pc_pyramid
2025-11-02 16:58:08 - log_args - INFO - fetching args:
2025-11-02 16:58:08 - log_args - INFO - pc = './pu/ds_ph.zarr'
2025-11-02 16:58:08 - log_args - INFO - out_dir = './pu/ds_ph_geo_pyramid'
2025-11-02 16:58:08 - log_args - INFO - x = './pu/ds_e.zarr/'
2025-11-02 16:58:08 - log_args - INFO - y = './pu/ds_n.zarr/'
2025-11-02 16:58:08 - log_args - INFO - yx = None
2025-11-02 16:58:08 - log_args - INFO - ras_resolution = 20
2025-11-02 16:58:08 - log_args - INFO - ras_chunks = (256, 256)
2025-11-02 16:58:08 - log_args - INFO - pc_chunks = 65536
2025-11-02 16:58:08 - log_args - INFO - processes = False
2025-11-02 16:58:08 - log_args - INFO - n_workers = 1
2025-11-02 16:58:08 - log_args - INFO - threads_per_worker = 2
2025-11-02 16:58:08 - log_args - INFO - dask_cluster_arg = {}
2025-11-02 16:58:08 - log_args - INFO - fetching args done.
2025-11-02 16:58:08 - pc_pyramid - INFO - clean out dir
2025-11-02 16:58:09 - zarr_in

In [None]:
with open('./raw/meta.toml','r') as meta_file:
    dates = toml.load(meta_file)['dates']

def _phase_overlay(pyramid_dir, z_name, z_range, **pc_plot_kwargs):
    '''Overlay the first two layers returned by `mc.pc_plot` and style them for phase display.

    `z_name` is used both as the name of the colour dimension and as the plot title.
    Extra keyword arguments are forwarded to `mc.pc_plot`.
    '''
    layers = mc.pc_plot(pyramid_dir, level_increase=0, **pc_plot_kwargs)
    overlay = layers[0]*layers[1]
    # interferogram i is labelled with the dates of images i and i+1
    overlay = overlay.redim(
        i=hv.Dimension('i', label='Intf index', range=(0,len(dates)-2), value_format=(lambda i: dates[i]+'_'+dates[i+1])),
        x=hv.Dimension('lon', label='Longitude'),
        y=hv.Dimension('lat',label='Latitude'),
        z=hv.Dimension(z_name,range=z_range),
    )
    image_opts = hv.opts.Image(
        cmap='colorwheel', frame_width=500, frame_height=400, colorbar=True,
        default_tools=['pan', WheelZoomTool(zoom_on_axis=False), 'save', 'reset', 'hover'],
        active_tools=['wheel_zoom'],
        title=z_name,
    )
    points_opts = hv.opts.Points(
        color=z_name, cmap='colorwheel', frame_width=500, frame_height=400, colorbar=True,
        default_tools=['pan', WheelZoomTool(zoom_on_axis=False), 'save', 'reset', 'hover'],
        active_tools=['wheel_zoom'],
        title=z_name,
    )
    return overlay.opts(image_opts, points_opts)

# wrapped phase: the pyramid stores per-image phase, turned into sequential interferograms for display
ds_geo_intf_plot = _phase_overlay(
    './pu/ds_ph_geo_pyramid', 'Wrapped Phase', (-np.pi,np.pi),
    post_proc_ras='intf_seq', post_proc_pc='intf_seq',
)
# NOTE(review): 'colorwheel' is a cyclic colormap; for unwrapped phase a non-cyclic one may read better — confirm intent.
ds_geo_unw_plot = _phase_overlay('./pu/ds_unw_geo_pyramid', 'Unwrapped Phase', (-10,10))

In [None]:
# Lay out the wrapped and unwrapped phase plots side by side, each overlaid on Esri imagery tiles.
hv.element.tiles.EsriImagery()*(ds_geo_intf_plot+ds_geo_unw_plot)

In [None]:
#| hide
# Export the `#| export` cells of this notebook to the `cli/pu` library module with nbdev.
import nbdev; nbdev.nbdev_export()