# Generate Datasets

In [None]:
# Install the Piscis package, which provides the piscis.* modules imported below.
!pip install piscis

In [None]:
import hashlib
import json
import nd2
import numpy as np
import requests
import xarray as xr

from huggingface_hub import HfFileSystem
from jax import random
from pathlib import Path

from piscis.data import generate_dataset
from piscis.paths import HF_DATASETS_DIR
from piscis.utils import fit_coords, pad_and_stack, remove_duplicate_coords, snap_coords

In [None]:
# Root directory for all generated outputs: an 'outputs' folder next to the current working directory.
outputs_path = Path().absolute().parent / 'outputs'

# Directory holding every dataset, with one subdirectory for Piscis datasets and one for deepBlink datasets.
datasets_path = outputs_path / 'datasets'
piscis_datasets_path = datasets_path / 'piscis'
deepblink_datasets_path = datasets_path / 'deepblink'

# Create the whole directory tree up front; directories that already exist are left untouched.
for directory in (outputs_path, datasets_path, piscis_datasets_path, deepblink_datasets_path):
    directory.mkdir(parents=True, exist_ok=True)

### Generate datasets from NimbusImage.

In [None]:
# Local directory of the 20230905 Piscis dataset.
dataset_path = piscis_datasets_path / '20230905'

# Remote source and local destination of the raw images and annotations.
remote_raw_dir = f'{HF_DATASETS_DIR}20230905/raw'
local_raw_dir = str(dataset_path / 'raw')

# Recursively download the raw data from Hugging Face.
hs = HfFileSystem()
hs.download(remote_raw_dir, local_raw_dir, recursive=True)

In [None]:
# Locations of the downloaded raw images and annotations.
raw_path = dataset_path / 'raw'
images_path = raw_path / 'images'
annotations_path = raw_path / 'annotations'

# Output locations for the combined dataset and for the datasets that include DAPI.
combined_path = dataset_path / 'combined'
with_dapi_path = dataset_path / 'with_dapi'

# Make sure both output directories exist before anything is written to them.
for output_dir in (combined_path, with_dapi_path):
    output_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# # List all datasets.
# datasets_list = [file.stem for file in images_path.glob('*.nd2')]

# # Loop over datasets.
# for dataset in datasets_list:

#     # Load ND2 image.
#     image = nd2.imread(images_path / f'{dataset}.nd2', dask=True, xarray=True)
#     channel_names = tuple(image.coords['C'].to_numpy())
#     dapi_channel_index = next((i for i, name in enumerate(channel_names) if name.startswith('DAPI')), None)
#     image.coords['channel_names'] = image.coords['C']
#     for axis, size in image.sizes.items():
#         image.coords[axis] = np.arange(size)
#     image.attrs['axes_calibration'] = image.attrs.pop('metadata')['metadata'].channels[0].volume.axesCalibration

#     # Load annotations.
#     with open(annotations_path / f'{dataset}.json') as f:
#         annotations = json.load(f)['annotations']

#     # Group point annotations by channel and frame.
#     frames_dict = {'C': set(), 'P': set(), 'Z': set(), 'T': set()}
#     if dapi_channel_index is not None:
#         frames_dict['C'].add(dapi_channel_index)
#     for annotation in annotations:
#         if annotation['shape'] == 'point':
#             location = annotation['location']
#             channel_index = annotation['channel']
#             frames_dict['C'].add(channel_index)
#             frames_dict['P'].add(location['XY'])
#             frames_dict['Z'].add(location['Z'])
#             frames_dict['T'].add(location['Time'])

#     # Save subsetted image.
#     image = image.isel(**{k: list(frames_dict[k]) for k in image.coords if k not in ['channel_names', 'Y', 'X']})
#     image.to_netcdf(images_path / f'{dataset}.nc', format='NETCDF4')

In [None]:
# Define axes order.
axes_order = ('P', 'Z', 'T', 'C')

# Map each non-channel frame axis of the image to the annotation location key that indexes it.
location_keys = {'P': 'XY', 'Z': 'Z', 'T': 'Time'}

# List all datasets (one NetCDF image file per dataset).
datasets_list = [file.stem for file in images_path.glob('*.nc')]

# Loop over datasets.
for dataset in datasets_list:

    # Derive a reproducible seed from the dataset name: the first 8 hex characters of its MD5 hash, read as an integer.
    hashed = hashlib.md5(dataset.encode()).hexdigest()
    seed = int(hashed[:8], 16)

    # Generate a random key.
    key = random.PRNGKey(seed)

    # Load the NetCDF image and look up the positional index of the DAPI channel (None if there is no such channel).
    image = xr.load_dataarray(images_path / f'{dataset}.nc')
    channel_names = tuple(image.coords['channel_names'].to_numpy())
    dapi_channel_index = next((i for i, name in enumerate(channel_names) if name.startswith('DAPI')), None)

    # Coordinate values of each frame axis, fetched once per image instead of once per annotation.
    axis_values = {
        axis: image.coords[axis].to_numpy()
        for axis in image.sizes
        if axis == 'C' or axis in location_keys
    }

    # Load annotations.
    with open(annotations_path / f'{dataset}.json') as f:
        annotations = json.load(f)['annotations']

    # Group point annotations by channel and frame: coords_dict[channel_index][frame_index] -> list of (y, x).
    coords_dict = {}
    for annotation in annotations:

        # Only point annotations are used.
        if annotation['shape'] != 'point':
            continue

        location = annotation['location']
        coordinates = annotation['coordinates'][0]

        # Translate the annotation's channel and location values into positional indices along the image axes.
        channel_index = np.argwhere(axis_values['C'] == annotation['channel'])[0, 0]
        frame_index = []
        for axis in image.sizes:
            if axis == 'C':
                frame_index.append(channel_index)
            elif axis in location_keys:
                frame_index.append(np.argwhere(axis_values[axis] == location[location_keys[axis]])[0, 0])
        frame_index = tuple(frame_index)

        # Store the point as (y, x), shifted by half a pixel
        # (presumably converting between pixel-corner and pixel-center conventions; confirm against NimbusImage).
        coords_dict.setdefault(channel_index, {}).setdefault(frame_index, []).append(
            (coordinates['y'] - 0.5, coordinates['x'] - 0.5)
        )

    # Loop over annotations by channel.
    for channel_index, channel_coords_dict in coords_dict.items():

        # Create empty lists.
        fish_images_list = []
        with_dapi_images_list = []
        coords_list = []

        # Loop over annotations by frame.
        for frame_index, coords in channel_coords_dict.items():

            # Obtain FISH image.
            fish_image = image[frame_index].compute()

            # Process coordinates by snapping, fitting, and removal of duplicates.
            coords = np.array(coords)
            coords = snap_coords(coords, fish_image)
            coords = fit_coords(coords, fish_image)
            coords = remove_duplicate_coords(coords)

            # Add image and processed coordinates.
            fish_images_list.append(fish_image)
            coords_list.append(coords)

            # Build the FISH + DAPI stack only when this image actually has a DAPI channel. Doing it unconditionally
            # would raise a NameError (or silently reuse a DAPI image left over from a previous dataset).
            if dapi_channel_index is not None:

                # Replace the last frame-axis index with the DAPI channel index.
                # NOTE(review): this assumes C is the last frame axis of the image; confirm.
                dapi_index = (*frame_index[:-1], dapi_channel_index)

                # Obtain DAPI image.
                dapi_image = image[dapi_index].compute()
                with_dapi_images_list.append(np.stack((fish_image, dapi_image)))

        # Generate dataset.
        generate_dataset(dataset_path / f'{dataset}_{channel_names[channel_index]}.npz', fish_images_list, coords_list, key)

        # Generate dataset with DAPI if possible.
        if dapi_channel_index is not None:
            generate_dataset(with_dapi_path / f'{dataset}_{channel_names[channel_index]}_with_dapi.npz', with_dapi_images_list, coords_list, key)

### Subset deepBlink datasets.

In [None]:
# Define the URL for the Figshare API.
api_url = 'https://api.figshare.com/v2/articles/12958037'

# Get the list of files attached to the Figshare article. Check the HTTP status first so that a failed request
# is reported as such instead of surfacing later as a JSON decoding error or a missing 'files' key.
listing = requests.get(api_url)
listing.raise_for_status()
files = listing.json()['files']

# Download every .npz file of the article into the deepBlink datasets directory.
for file in files:
    file_name = file['name']
    if not file_name.endswith('.npz'):
        continue
    download_url = file['download_url']

    # Stream the file to disk in blocks; the context manager releases the connection even if the download fails.
    with requests.get(download_url, stream=True) as response:
        response.raise_for_status()
        with open(deepblink_datasets_path / file_name, 'wb') as handle:
            for block in response.iter_content(1024):
                handle.write(block)

In [None]:
# Define the train-valid-test split: the number of images kept per split when subsetting the deepBlink datasets.
splits = {'train': 42, 'valid': 9, 'test': 9}

In [None]:
# Datasets for which only the first quarter of each split is considered
# (presumably the files are ordered so that the first quarter holds the SNR=1 images; verify against deepBlink).
first_quarter_only = ('microtubule', 'receptor', 'vesicle')

# Subset each deepBlink dataset in turn.
for deepblink_dataset in ('particle', 'microtubule', 'receptor', 'vesicle'):

    # Read every array of the dataset into a plain dictionary.
    # NOTE: allow_pickle=True can execute code embedded in a malicious file; only use it on trusted downloads.
    deepblink_ds = dict(np.load(deepblink_datasets_path / f'{deepblink_dataset}.npz', allow_pickle=True))

    # Visit the image arrays ('x_<split>'); the matching coordinate arrays ('y_<split>') are handled alongside them.
    for images_key, images in deepblink_ds.items():

        if not images_key.startswith('x'):
            continue

        # The split name is the part of the key after the underscore.
        split = images_key.split('_')[1]
        coords_key = f'y_{split}'
        targets = deepblink_ds[coords_key]

        # Restrict to the first quarter of the split where required.
        n_available = len(images)
        if deepblink_dataset in first_quarter_only:
            n_available //= 4
            images, targets = images[:n_available], targets[:n_available]

        # Desired size of this split and the stride that spreads the picks over the available images.
        n = splits[split]
        stride = (n_available - 1) // (n - 1)

        # Take n evenly strided images, crop each to its central 256 x 256 window (rows and columns 128..383),
        # and store them as an object array with one image per element.
        cropped = images[::stride][:n][:, 128:384, 128:384]
        images_subset = np.empty(n, dtype=object)
        images_subset[:] = list(cropped)
        deepblink_ds[images_key] = images_subset

        # Take the matching coordinates, shift them into the cropped frame, and drop points outside the crop.
        coords_subset = targets[::stride][:n]
        offset = np.array((128, 128))
        for i, coords in enumerate(coords_subset):
            shifted = coords - offset
            inside = np.all((shifted >= -0.5) & (shifted <= 255.5), axis=-1)
            coords_subset[i] = shifted[inside]
        deepblink_ds[coords_key] = coords_subset

    # Save subset of the deepBlink dataset.
    np.savez(dataset_path / f'{deepblink_dataset}_subset.npz', **deepblink_ds)

### Generate a combined dataset.

In [None]:
# Keys of the arrays read from every dataset file: images (x) and coordinates (y) for each split.
split_keys = ('x_train', 'y_train', 'x_valid', 'y_valid', 'x_test', 'y_test')

# List all npz files in sorted order. Path.glob yields files in a filesystem-dependent order, so sorting keeps the
# order of the samples in the combined dataset reproducible across machines.
npz_list = sorted(dataset_path.glob('*.npz'))
if not npz_list:
    raise FileNotFoundError(f'No .npz datasets found in {dataset_path}.')

# Collect the arrays of every dataset, grouped by key.
arrays_by_key = {split_key: [] for split_key in split_keys}
for npz in npz_list:

    # Load dataset; the context manager closes the underlying file once the arrays have been read.
    # NOTE: allow_pickle=True can execute code embedded in a malicious file; only load trusted files.
    with np.load(npz, allow_pickle=True) as ds:

        # Add images and coords for each split.
        for split_key in split_keys:
            arrays_by_key[split_key].append(ds[split_key])

# Concatenate across datasets.
x_train = np.concatenate(arrays_by_key['x_train'])
y_train = np.concatenate(arrays_by_key['y_train'])
x_valid = np.concatenate(arrays_by_key['x_valid'])
y_valid = np.concatenate(arrays_by_key['y_valid'])
x_test = np.concatenate(arrays_by_key['x_test'])
y_test = np.concatenate(arrays_by_key['y_test'])

# Pad and stack images.
x_train = pad_and_stack(x_train)
x_valid = pad_and_stack(x_valid)
x_test = pad_and_stack(x_test)

# Generate a combined dataset.
np.savez(
    combined_path / f'{dataset_path.stem}_combined.npz',
    x_train=x_train,
    y_train=y_train,
    x_valid=x_valid,
    y_valid=y_valid,
    x_test=x_test,
    y_test=y_test,
)