In [2]:
import os

import dask
import dask.array as da
import numpy as np
import seaborn as sns
import xarray as xr
from dask.diagnostics import Profiler
from dask.distributed import Client, SSHCluster

In [None]:
# Run top-level `await` expressions in notebook cells with the asyncio runner.
%autoawait asyncio

async def setup_dask_cluster(
    hosts=("panoseti-dfs0", "panoseti-dfs0", "panoseti-dfs1", "panoseti-dfs2", "panoseti"),
    nthreads=16,
    memory_limit="16GB",
    dashboard_address=":8797",
):
    """Start a Dask cluster over SSH and connect an asynchronous client to it.

    Args:
        hosts: Host names passed to ``SSHCluster``. The first entry hosts the
            scheduler and each remaining entry starts one worker, so a host
            name may be repeated to run a worker next to the scheduler or to
            run several workers on one machine.
        nthreads: Number of threads per worker.
        memory_limit: Memory limit per worker.
        dashboard_address: Address the scheduler dashboard listens on.

    Returns:
        A ``(client, cluster)`` tuple. The client is created with
        ``asynchronous=True``.

    Raises:
        ValueError: If fewer than two hosts are given (one for the scheduler
            and at least one for a worker).
    """
    hosts = list(hosts)
    if len(hosts) < 2:
        raise ValueError(
            "hosts must name the scheduler host followed by at least one worker host"
        )

    # Source text executed by each worker at start-up: clear the umask so files
    # created by the workers carry no permission restrictions.
    # (Alternative: point "preload" at a script such as /mnt/beegfs/worker_init.py.)
    preload_command = "import os; os.umask(0o000)"

    # The cluster object is awaited here so this coroutine only continues once
    # the cluster has been started.
    cluster = await SSHCluster(
        hosts=hosts,
        # NOTE(review): known_hosts=None turns off SSH host-key verification;
        # only appropriate on a trusted network.
        connect_options={"known_hosts": None},
        worker_options={
            "nthreads": nthreads,
            "memory_limit": memory_limit,
            "preload": [preload_command],
        },
        # port 0 lets the scheduler pick any free port.
        scheduler_options={"port": 0, "dashboard_address": dashboard_address},
    )

    # Connect to the cluster asynchronously.
    client = await Client(cluster, asynchronous=True)

    print(f"Dask dashboard link: {client.dashboard_link}")

    # Both objects are returned so later cells can use and shut them down.
    return client, cluster

# In a Jupyter Notebook, top-level `await` is available, so the coroutine is awaited directly.
# In a regular Python script, you would use: asyncio.run(setup_dask_cluster())
client, cluster = await setup_dask_cluster()
# Bare last expression: shows the client's rich repr as the cell output.
client

In [10]:
# Print the kernel's current working directory.
!pwd

/Users/nico/panoseti/panoseti_zarr


In [14]:
from pathlib import Path
import numcodecs
import zarr

# Root directories of the L0 and L1 data trees.
l0_dir = Path('/mnt/beegfs/data/L0')
l1_dir = Path('/mnt/beegfs/data/L1')

# Run directory and base name (no extension) of the file to process.
run_dir = 'obs_Lick.start_2024-07-25T04:34:06Z.runtype_sci-data.pffd'
pff_base = "start_2024-07-25T04:34:46Z.dp_img16.bpp_2.module_1.seqno_0"

# The input store and the median-subtracted output live side by side in the run directory.
zarr_in_path = l1_dir.joinpath(run_dir, f"{pff_base}.zarr")
out_path = l1_dir.joinpath(run_dir, f'atype_img-medsub.{pff_base}.zarr')
print(out_path)

# Fail early if the input store is missing.
assert zarr_in_path.exists()

# Open the store as a DataTree; chunks={} requests dask-backed arrays.
ds = xr.open_datatree(zarr_in_path, engine='zarr', consolidated=False, chunks={}, cache=True)
ds.images

In [None]:
# Build the processing graph for the opened store and write the result to Zarr.
#
# The umask is cleared for the duration of the write so the created files carry
# no permission restrictions; the previous value is restored in `finally`.
original_umask = os.umask(0o000)

try:
    # Select the data product from the store's own name only, so a parent
    # directory that happens to contain "ph" or "img" cannot pick the wrong branch.
    store_name = Path(zarr_in_path).name

    if 'ph' in store_name:
        if ds.images.shape[1] == 32:
            upper_right_slice = (slice(None), slice(0, 16), slice(16, 32))
            lower_left_slice = (slice(None), slice(16, 32), slice(0, 16))

            # Lazily extract both quadrants from the *original* array, so the
            # swap below never reads a region that was already overwritten.
            upper_right_quadrant = ds.images[upper_right_slice]
            lower_left_quadrant = ds.images[lower_left_slice]

            # Work on a copy so `ds.images` itself is left untouched.
            images_swapped = ds.images.copy()

            # Swap the quadrants. The raw arrays (.data) are assigned rather than
            # the DataArrays: the source and target regions have different y/x
            # positions, and xarray rejects DataArray assignment when their
            # dimension coordinates conflict.
            images_swapped[upper_right_slice] = lower_left_quadrant.data
            images_swapped[lower_left_slice] = upper_right_quadrant.data

            pre_preprocessed = images_swapped
        else:
            # NOTE(review): only the first 1000 frames are processed here — confirm
            # this limit is intended and not a debugging leftover.
            pre_preprocessed = ds.images[:1000]

        # Pedestal subtraction (shared by both ph cases).
        # NOTE(review): the meaning of the 800 offset is not visible here — confirm.
        add_baselines = pre_preprocessed + 800
        pedestal = add_baselines.median(dim=['time'])
        pedestal_sub = add_baselines - pedestal

        # Per-pixel threshold: five standard deviations over time.
        sigma_5 = pedestal_sub.std('time') * 5
        sigma_5_above_mask = (pedestal_sub > sigma_5)

        # Keep values above the threshold; everything else becomes NaN.
        ph_5_sigma_above = pedestal_sub.where(sigma_5_above_mask)

        # Alternatively, using dask.array.ma directly:
        # ph_5_sigma_above = da.ma.masked_array(pedestal_sub, mask=sigma_5_above_mask)

        # Build the write lazily (compute=False); it is executed once below.
        store_operation = ph_5_sigma_above.to_zarr(
            'ph_5s.zarr',
            mode='w',
            consolidated=False,
            compute=False,
            zarr_format=3,
        )

    elif 'img' in store_name:
        # Only every `frame_step`-th frame feeds the median estimates.
        # NOTE(review): a run with no more than `frame_step` frames contributes
        # only its first frame to these medians — confirm this is intended.
        frame_step = 100000

        # 1) Convert to float32 before any arithmetic to avoid integer overflow,
        #    then scale by 1/10 (ADC -> p.e., going by the variable name).
        img_adc_to_pe = (ds.images.astype('float32')) / 10.0  # DataArray backed by Dask

        # 2) 8x8 block medians on the strided subset, then median over time.
        block_8x8_medians = (
            img_adc_to_pe[::frame_step]
            .coarsen(y=8, x=8, boundary="trim")
            .median()
            .median('time')
        )  # dims: (y, x); 4x4 for 32x32 frames

        # 3) Upsample the block medians back to the frame size by repeating each
        #    value 8 times along both axes.
        upsampled_medians_da = da.repeat(block_8x8_medians.data, 8, axis=0)
        upsampled_medians_da = da.repeat(upsampled_medians_da, 8, axis=1)

        # Wrap in an xarray.DataArray to restore coordinates for broadcasting.
        upsampled_medians = xr.DataArray(
            upsampled_medians_da,
            dims=('y', 'x'),
            coords={'y': ds.images.y, 'x': ds.images.x},
        )

        # 4) Subtract the spatial 8x8 medians via broadcasting (no map_blocks needed).
        median_subtraction_8x8 = img_adc_to_pe - upsampled_medians  # (time, y, x)

        # 5) Per-pixel median over the strided subset of the result, then subtract it.
        supermedian_img = median_subtraction_8x8[::frame_step].median('time')
        median_subtraction_final = median_subtraction_8x8 - supermedian_img

        # 6) Zstandard compression at level 15, using the zarr codec class.
        compressor = zarr.codecs.ZstdCodec(level=15)

        # 7) Encoding is keyed by the variable name, so name the array first.
        median_subtraction_final.name = 'median_subtracted_data'
        encoding = {
            median_subtraction_final.name: {"compressors": [compressor]}
        }

        # 8) Build the write lazily (compute=False); it is executed once below.
        store_operation = median_subtraction_final.to_zarr(
            out_path,
            mode='w',
            consolidated=True,
            encoding=encoding,
            compute=False,
            zarr_format=3
        )

    else:
        # Without this branch `store_operation` would be unbound and the failure
        # would surface as an unrelated NameError below.
        raise ValueError(
            f"Cannot determine the data product of {zarr_in_path}: "
            "expected 'ph' or 'img' in the store name"
        )

    # Execute the whole graph and the write in a single pass.
    # NOTE(review): if the active Dask client was created with asynchronous=True,
    # confirm that this blocking .compute() call behaves as intended.
    result = store_operation.compute()
    display(result)
finally:
    os.umask(original_umask)

In [15]:
# view processed data
# Re-open the written store lazily (dask-backed) to inspect the result.
ds_clean = xr.open_datatree(
    out_path,
    consolidated=False,
    engine='zarr',
    chunks={},
    # cache=True
)

# The processing cell stores its result under the name 'median_subtracted_data'.
# A store written before that name was introduced presumably still exposes
# 'images', so fall back to it instead of failing.
clean_var_name = (
    'median_subtracted_data'
    if 'median_subtracted_data' in ds_clean.data_vars
    else 'images'
)
ds_clean[clean_var_name]

Unnamed: 0,Array,Chunk
Bytes,32.25 MiB,32.25 MiB
Shape,"(4128, 32, 32)","(4128, 32, 32)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 32.25 MiB 32.25 MiB Shape (4128, 32, 32) (4128, 32, 32) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",32  32  4128,

Unnamed: 0,Array,Chunk
Bytes,32.25 MiB,32.25 MiB
Shape,"(4128, 32, 32)","(4128, 32, 32)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
