# BigWarp .zarr to OME-TIFF

## Setting up a conda environment

```bash
conda env create -f environment.yml
conda activate zarr-ometiff
```

## Import packages

In [None]:
import os
import zarr
from pathlib import Path
import matplotlib.pyplot as plt
from skimage.transform import downscale_local_mean
from typing import Dict, List, Optional, Tuple, Iterator
import numpy as np
import tifffile
import math
import dask.array as da

## Define functions

In [None]:
def _zarr_image_get_pixelsize_um(root, group_name) -> Tuple[float, float]:
    # As an example a transform attribute for an image channel in a Bigwarp zarr:
    # {'axes': ['y', 'x'], 'scale': [0.0002495152876362007, 0.0002495152876362007, 0.001], 'translate': [0.0, 0.0, 0.0], 'units': ['mm', 'mm']}
    # It is inconsistent because sometimes it specifies three dimensions, and sometimes two.
    transform = root[group_name].attrs['transform']
    print(f'Image transform {transform}')

    assert transform['axes'] == ['y', 'x']
    assert transform['translate'] == [0.0, 0.0, 0.0]
    assert transform['units'] == ['mm', 'mm']

    scale = transform['scale']
    assert len(scale) == 3  # the 3rd dimensions is probably a default z-plane distance of 1 micron

    pixel_size_y_um = scale[0] * 1000.0 
    pixel_size_x_um = scale[1] * 1000.0

    return (pixel_size_y_um, pixel_size_x_um)

In [None]:
# Define function
def write_pyramidal_ome_tiff(img_stack: da.Array,                # dask array image stack of shape (channels, y, x)
                             pyramid_filename: str,
                             channel_names: Optional[List[str]],
                             pixel_size_um: Optional[Tuple[float, float]],        # pixel size in microns, or None
                             compression: Optional[str],         # 'zlib' or None
                             tile_size: int,                     # tile size (in pixels) in output OME TIFF file
                             max_levels: int,
                             downsample_method: str) -> None:    # downsample method = 'box' (better quality, slower) or 'nearest neighbor' (faster, poorer quality)
    """Write an image stack to a tiled, pyramidal OME TIFF file (BigTIFF).

    Level 0 holds the full resolution image; each following level is downsampled
    by another factor of two and stored as a sub-IFD of level 0. The output folder
    is created if it does not exist yet.

    Args:
        img_stack: image stack of shape (channels, y, x).
        pyramid_filename: path of the file to write.
        channel_names: one name per channel for the OME metadata, or None.
        pixel_size_um: (y, x) pixel size in micrometers for the OME metadata, or None.
        compression: compression passed on to tifffile, e.g. 'zlib', or None.
        tile_size: height and width, in pixels, of the (square) tiles in the file.
        max_levels: maximum number of pyramid levels, full resolution included.
            The number actually written is limited to the range described below.
        downsample_method: 'box' or 'nearest neighbor'.
    """
    ome_metadata = _make_ome_metadata(img_stack, channel_names, pixel_size_um)

    tile_sizes = (tile_size, tile_size)

    options = dict(tile=tile_sizes,
                   photometric='minisblack',
                   compression=compression,
                   metadata=ome_metadata,
                   software=_creator())

    num_channels, image_height, image_width = img_stack.shape

    # Limit the number of levels to a sensible range:
    # - at least 1, so that the full resolution image is always written and the
    #   sub-IFD count below never becomes negative;
    # - at most the number of distinct levels. Level L measures ceil(size / 2**L)
    #   pixels, so once the largest dimension has shrunk to a single pixel every
    #   further level would be one more identical 1 x 1 image.
    largest_dimension = max(int(image_height), int(image_width), 1)
    max_useful_levels = (largest_dimension - 1).bit_length() + 1
    num_pyramid_levels: int = max(1, min(max_levels, max_useful_levels))
    if num_pyramid_levels != max_levels:
        print(f'Using {num_pyramid_levels} pyramid level(s) instead of the requested {max_levels}')

    # Create output folder if it does not exist yet.
    Path(pyramid_filename).parent.mkdir(parents=True, exist_ok=True)

    print(f'Writing pyramidal OME TIFF file {pyramid_filename} (compression: {options["compression"]})')
    with tifffile.TiffWriter(pyramid_filename, ome=True, bigtiff=True) as tif:

        # Write full resolution image. The data is handed over as a generator of
        # tiles, so the shape and dtype have to be passed explicitly.
        print(f'Writing level 0: {image_width} x {image_height} px')
        downsample_factor = 1
        tif.write(data=_tiles_generator(img_stack, tile_sizes, downsample_factor, downsample_method),
                  shape=(num_channels, image_height, image_width, 1),
                  dtype=img_stack.dtype,
                  subifds=num_pyramid_levels-1,
                  **options)

        # Save downsampled pyramid images to the subifds
        for level in range(1, num_pyramid_levels):
            downsample_factor = 2**level
            # Rounded up, to match the size of the images produced by the downsampling.
            downsampled_data_shape = math.ceil(image_height / downsample_factor), math.ceil(image_width / downsample_factor)
            print(f'Writing level {level}: {downsampled_data_shape[1]} x {downsampled_data_shape[0]} px')
            tif.write(data=_tiles_generator(img_stack, tile_sizes, downsample_factor, downsample_method),
                      shape=(num_channels, downsampled_data_shape[0], downsampled_data_shape[1], 1),
                      dtype=img_stack.dtype,
                      subfiletype=1,
                      **options)


def _tiles_generator(img_stack: da.Array,
                     tile_sizes: Tuple[int, int],
                     downsample_factor: int,
                     downsample_method: str) -> Iterator[np.ndarray]:
    """Yield the tiles of every (downsampled) channel of an image stack.

    Channels are handled one after the other; the tiles of a channel are yielded
    row by row, left to right. Tiles on the bottom and right border may be smaller
    than ``tile_sizes``.

    See also https://forum.image.sc/t/tifffile-ome-tiff-generation-is-taking-too-much-ram/41865/16
    and https://github.com/labsyspharm/ashlar/blob/5bf5b8710f456e68e33ff232708cda0b1c904a33/ashlar/reg.py
    """
    tile_height, tile_width = tile_sizes
    num_channels = img_stack.shape[0]

    for channel_index in range(num_channels):
        print(f'  channel {channel_index+1}/{num_channels}')

        # Pull the complete channel into memory as a numpy array before downsampling.
        # We give up the limited peak memory use of the dask chunks, but the naive
        # downsampling then cannot produce artefacts at chunk borders, and cutting
        # the result into OME TIFF tiles (instead of dask chunks) becomes trivial.
        # IMPROVEME: for downsampling the dask array directly, see for inspiration
        # https://github.com/spatial-image/multiscale-spatial-image/blob/0c6f65cdc69cb069e81cdc07e7f3f5441f0cc4e5/multiscale_spatial_image/to_multiscale/_dask_image.py#L100
        # Retiling may still be needed afterwards.
        plane = _downsample_image(img_stack[channel_index].compute(),
                                  downsample_factor,
                                  downsample_method)

        # Cut the plane into horizontal bands, and each band into tiles.
        num_rows, num_cols = plane.shape
        for top in range(0, num_rows, tile_height):
            band = plane[top:top+tile_height]
            for left in range(0, num_cols, tile_width):
                yield band[:, left:left+tile_width].copy()


def _downsample_image(img: np.ndarray,
                      downsample_factor: int,
                      downsample_method: str) -> np.ndarray:

    if downsample_factor != 1:
        if downsample_method == 'box':
            # Box filter. This yields dramatically better quality than nearest neighbor but is slower.
            img = downscale_local_mean(img, (downsample_factor, downsample_factor)).astype(img.dtype)
        else:
            # Nearest neighbor downsampling. Fast but poor quality.
            img = img[::downsample_factor, ::downsample_factor]

    return img


def _make_ome_metadata(img_stack: da.Array,
                       channel_names: Optional[List[str]],
                       pixel_size_um: Optional[Tuple[float, float]]):
    """Build the OME metadata dictionary for an image stack.

    The dictionary always holds the 'Creator'. The physical pixel size (in µm) is
    added when ``pixel_size_um`` (y, x) is given, and the channel names are added
    when ``channel_names`` is given; there must then be one name per channel
    (first axis of ``img_stack``).
    """
    metadata = {'Creator': _creator()}

    if pixel_size_um:
        size_y_um, size_x_um = pixel_size_um
        print(f'Pixel size: {pixel_size_um} micrometer')
        metadata.update({
            'PhysicalSizeX': str(size_x_um),
            'PhysicalSizeXUnit': 'µm',
            'PhysicalSizeY': str(size_y_um),
            'PhysicalSizeYUnit': 'µm',
        })

    if channel_names:
        print(f'Channel names: {channel_names}')
        assert len(channel_names) == img_stack.shape[0]
        metadata['Channel'] = {'Name': channel_names}

    return metadata


def _creator() -> str:
    return f'zarr-ometiff'


In [None]:
def group_name_to_channel_name(group_name: str, renaming_dict: Optional[Dict[str, str]]=None) -> str:
    """Derive the channel name from the name of a zarr group.

    A group name looks like this:
    '/warped/001_AntigenCycle_Donkey_anti-goat-FITC_V50_FITC_16bit_M-20x-S_Fluor_full_sensor_B-1_R-6_W-0_G-1_F-1_E-1200.0.tif-Donkey'
    The channel name is the part that follows the second '/' (up to a third '/',
    if there is one).

    If a non-empty ``renaming_dict`` is given, the extracted name is looked up in it
    and the mapped name is returned instead; a name that is missing from the
    dictionary raises a KeyError.
    """
    path_parts = group_name.split('/')
    extracted_name = path_parts[2]

    # No (or an empty) renaming dictionary: keep the name as found in the zarr.
    if not renaming_dict:
        return extracted_name

    return renaming_dict[extracted_name]

## Open BigWarp .zarr

In [None]:
# Open zarr
# NOTE: Set path to your BigWarp Zarr
zarr_path = r'D:\Frank\BigWarp-Troubleshooting-MarineLab\bigwarp.zarr'
assert Path(zarr_path).exists()

# Mode 'r': the zarr is only read by this notebook.
root = zarr.open(zarr_path, mode='r')

In [None]:
# Check contents of Zarr: display an overview of what is stored in it
root.tree()

In [None]:
# Get the name of the first group in the zarr (in the case of bigwarp zarr there is only one)
group_key = list(root)[0]
group = root[group_key]
# Last expression of the cell, so the notebook displays the group.
group

In [None]:
# Collect the name of every item in the group (one item per image channel).
channel_group_names = [child.name for _key, child in group.items()]
channel_group_names

## Visual exploration of the zarr

In [None]:
# Read the first channel as an example
img = root[channel_group_names[0]]
# Last expression of the cell, so the notebook displays a summary of the image.
img

In [None]:
# The physical pixel size can be recovered from the zarr attributes dictionary associated with the zarr group holding the (single channel) image.
# The 'transform' entry shown here is the one parsed by _zarr_image_get_pixelsize_um().
img.attrs['transform']

In [None]:
# Plot image (subsampled: only every 4th row and every 4th column is shown)
plt.imshow(img[::4,::4])

## Pyramidal OME-TIFF Writing

## Save zarr as single-channel OME-TIFFs

In [None]:
# Set parameters
tile_size: int = 1024       # height and width (in pixels) of the tiles in the OME-TIFF files
max_levels: int = 6         # number of image pyramid levels
downsample_method: str = 'box'           # 'box' (better quality, slower) or 'nearest neighbor' (faster, poorer quality)
compression: Optional[str] = 'zlib'      # 'zlib' or None

In [None]:
# Define dictionary for mapping BigWarp channel names back to the original filenames.
# Keys are the channel names as stored in the zarr, values are the names used for the output files.
# Every channel in the zarr needs an entry: the lookup of a missing channel name raises a KeyError.
orig_filenames: Dict[str, str] = {
    'cycle2-s0-DAPI' : 'some_other_filename_DAPI',
    'cycle2-s0-Opal_520' : 'some_other_filename_Opal_520',
    'cycle2-s0-Opal_570' : 'some_other_filename_Opal_570',
    'cycle2-s0-Opal_690' : 'some_other_filename_Opal_690',
    'cycle2-s0-Opal_780' : 'some_other_filename_Opal_780',
    'cycle2-s0-Sample_AF' : 'some_other_filename_Sample_AF'
}
# NOTE: Change these names to the BigWarp channel names and the desired output names.

In [None]:
# Specify path to output folder
# (the folder is created when the files are written, if it does not exist yet)
output_folder = r'D:\Frank\BigWarp-Troubleshooting-MarineLab\registered'
# NOTE: Specify the folder where you want to save the output files.

In [None]:
# Convert zarr to single-channel OME-TIFFs: one pyramidal file per channel
for channel, channel_group_name in enumerate(channel_group_names):
    channel_name = group_name_to_channel_name(channel_group_name, orig_filenames)
    print(f'Channel {channel}: {channel_group_name} -> {channel_name}')

    # The output file is named after the (renamed) channel.
    pyramid_filename = os.path.join(output_folder, f'{channel_name}.tif')

    # The writer expects a stack of shape (channels, y, x): add a channel axis of length 1.
    zarr_channel = root[channel_group_name]
    img_channel = da.expand_dims(da.array(zarr_channel), axis=0)

    pixel_size_um: Tuple[float, float] = _zarr_image_get_pixelsize_um(root, channel_group_name)

    write_pyramidal_ome_tiff(img_stack=img_channel,
                             pyramid_filename=pyramid_filename,
                             channel_names=[channel_name],
                             pixel_size_um=pixel_size_um,
                             compression=compression,
                             tile_size=tile_size,
                             max_levels=max_levels,
                             downsample_method=downsample_method)