# Assigning transcripts to cells using the Euclidean Distance Transform
Frank Vernaillen, July 2022

This notebook explores an approach for assigning each transcript detected in a spatial transcriptomics experiment to the nearest cell.

Only the cells' nuclei were detected, so the precise extent of each cell is unknown. For each transcript, its (x, y) position is known. The nuclei were segmented, and the segmentation is available as a labels image in which all pixels belonging to a given nucleus share the same value, which is unique per nucleus. Background pixels have the value 0.

Since thousands of nuclei and hundreds of thousands of transcripts are detected, a brute force calculation of the distance from each transcript to the closest nucleus pixel is computationally prohibitive. Instead, we first calculate the Euclidean Distance Transform (EDT) of the nucleus labels image, using an EDT implementation which, next to the distance transform image (which we do not need), also returns for each input pixel the coordinates of the closest nucleus pixel. With these coordinates we can create an image where each pixel value is the identity of the closest nucleus. A simple lookup of a transcript position (x, y) in this image then yields the identity of the closest cell. So after EDT construction, finding the closest cell for a given transcript becomes an 𝒪(1) operation.
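To make the idea concrete, here is a minimal sketch on a tiny toy image. The array `toy_labels` and the two nuclei in it are made up for illustration; the actual data is processed further down in this notebook.

```python
import numpy as np
import scipy.ndimage

# Toy labels image: 0 is background, 1 and 2 are two hypothetical nuclei.
toy_labels = np.array(
    [
        [1, 1, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 2, 2],
        [0, 0, 0, 0, 2, 2],
    ]
)

# distance_transform_edt measures the distance to the nearest zero-valued pixel,
# so nucleus pixels must be zero (False) and background pixels nonzero (True).
background = toy_labels == 0

# Only the coordinates of the nearest nucleus pixel are needed here.
nearest = scipy.ndimage.distance_transform_edt(
    background, return_distances=False, return_indices=True
)

# Replace every pixel by the label of its nearest nucleus pixel.
toy_edt_labels = toy_labels[nearest[0], nearest[1]]

# Assigning a transcript is then a single array lookup at [row, column] = [y, x].
y, x = 0, 2
print(toy_edt_labels[y, x])
```

The pixel at row 0, column 2 is background in `toy_labels`, but in `toy_edt_labels` it carries the label of the nucleus next to it.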

## Data

For this notebook we need an image with the labels of the segmented nuclei, and the text file with the transcript coordinates.

In [None]:
labels_file = r"E:\Frank\napari-sparrow\experiments\scipy-edt\04032022MartinLiver2_W0A1_DAPI-quicklabels.tiff"
transcripts_file = r"E:\Frank\napari-sparrow\experiments\scipy-edt\04032022MartinLiver2_W0A1_results.txt"

## Imports

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.ndimage
import tifffile

## Read nuclei segmentation

In [None]:
labels = tifffile.imread(labels_file)

## Calculate Euclidean Distance Transform (EDT)

We need an EDT implementation which not only calculates the distance transform itself, but also returns the coordinates of the closest "feature" (=nucleus) pixel. The SciPy implementation offers this [functionality](https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.distance_transform_edt.html) through its `return_indices` parameter.

The OpenCV implementation can also return the closest pixel, but it seems that in that case the distance transform is not exact. The [documentation](https://docs.opencv.org/4.6.0/d7/d1b/group__imgproc__misc.html#ga8a0b7fdfcb7a13dde018988ba3a43042) is not entirely clear about this though.

The SciPy implementation handles the image sizes we need for Resolve experiments (~15000 x 15000 pixels) quite well. For experiments with significantly larger images, such as Vizgen experiments, we may run into memory issues. Ideally we would use a dask implementation of the EDT, both to handle large images and for its parallelism, but unfortunately [no such dask implementation](https://image.dask.org/en/latest/coverage.html) exists at this time.
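Since the distance image itself is not needed for the assignment, one option that might reduce memory use is to request only the indices. The sketch below tries this on a small synthetic image and reports the size of the index array. The helper name `nearest_label_image` and the synthetic data are assumptions made for this illustration, and the extrapolation to a full image is only a rough estimate that ignores temporary arrays allocated inside SciPy.

```python
import numpy as np
import scipy.ndimage


def nearest_label_image(labels):
    """Hypothetical helper: nearest-nucleus label per pixel, without the distance image."""
    indices = scipy.ndimage.distance_transform_edt(
        labels == 0, return_distances=False, return_indices=True
    )
    return labels[indices[0], indices[1]], indices


# Small synthetic labels image with two single-pixel "nuclei".
small_labels = np.zeros((200, 300), dtype=np.uint16)
small_labels[50, 60] = 1
small_labels[150, 240] = 2

small_edt_labels, small_indices = nearest_label_image(small_labels)

# The index array stores one coordinate per image dimension for every pixel.
bytes_per_pixel = small_indices.nbytes / small_labels.size
estimate_gb = bytes_per_pixel * 15000**2 / 1e9
print(f"Index array: {bytes_per_pixel:.0f} bytes per pixel")
print(f"Rough estimate for a 15000 x 15000 image: {estimate_gb:.1f} GB")
```

The cells below keep `return_distances=True` because the distance image is also displayed in the figures.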

In [None]:
def calculate_edt(labels):
    """Calculate the Euclidean Distance Transform (EDT) of the labels input image"""
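    # distance_transform_edt() measures the distance to the nearest zero-valued pixel,
    # so we make the background 1 and the nuclei 0.
    labels_binary = (labels == 0).astype(np.uint8)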

    edt_distances, edt_indices = scipy.ndimage.distance_transform_edt(
        labels_binary, return_distances=True, return_indices=True
    )

    # The returned 'edt_indices' are the coordinates of the closest nucleus pixel.
    # So for any pixel (i,j) the closest nucleus pixel is given by (edt_indices[0][i,j], edt_indices[1][i,j]).
    edt_labels = labels[edt_indices[0], edt_indices[1]]

    return edt_labels, edt_distances, labels_binary

Calculate the EDT from the labels image with the nuclei.

In [None]:
%%time
edt_labels, edt_distances, labels_binary = calculate_edt(labels)

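In [None]:
# Count the distinct nonzero labels; np.max(labels) would only be correct if labels are consecutive.
print(f"Number of nuclei: {np.count_nonzero(np.unique(labels))}")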

In [None]:
def display_edt_result(labels, binarized_labels, edt_distances, edt_labels):
    # Get range of pixel values, to get consistent colormap for labels and edt_labels.
    vmin = np.min(edt_labels)
    vmax = np.max(edt_labels)

    fig, axs = plt.subplots(2, 2, figsize=(10, 10))
    axs[0, 0].imshow(labels, vmin=vmin, vmax=vmax, interpolation="nearest")
    axs[0, 1].imshow(binarized_labels, cmap="binary_r", interpolation="nearest")
    axs[1, 0].imshow(edt_labels, vmin=vmin, vmax=vmax, interpolation="nearest")
    axs[1, 1].imshow(edt_distances, cmap="binary_r", interpolation="nearest")
    fig.tight_layout()
    plt.show()

In [None]:
# Show Euclidean Distance Transform (EDT) result
display_edt_result(labels, labels_binary, edt_distances, edt_labels)

In [None]:
# Show close-up Euclidean Distance Transform (EDT) result of bottom right corner
w = 1000
display_edt_result(
    labels[-w:, -w:],
    labels_binary[-w:, -w:],
    edt_distances[-w:, -w:],
    edt_labels[-w:, -w:],
)

## Read transcripts

Let's read the transcript information into a pandas dataframe.

In [None]:
df = pd.read_csv(
    transcripts_file,
    delimiter="\t",
    header=None,
    usecols=[0, 1, 2, 3],
    names=["x", "y", "z", "gene"],
)

print(f"Number of transcripts: {df.shape[0]}")

## Assign transcripts to cells
Using the nearest-nucleus labels image obtained with the EDT, this is now an 𝒪(1) lookup per transcript.

In [None]:
# NumPy images are indexed as [row, column], i.e. [y, x], so we put y first.
transcript_coords = df[["y", "x"]].to_numpy()

We'll write a trivial function to look up transcript coordinates in the cell labels image created via the EDT.

In [None]:
def assign_transcripts_to_cells(transcript_coords, edt_labels):
    ii = transcript_coords[:, 0]
    jj = transcript_coords[:, 1]
    cell_ids = edt_labels[ii, jj]
    return cell_ids
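The lookup above assumes that all coordinates are integers and fall inside the image. Note that negative indices would silently wrap around to the other side of the image. As a minimal sketch, a more defensive variant could look like this; the function name `assign_transcripts_safely`, the `outside_id` parameter and the choice to round to the nearest pixel are assumptions made for this illustration, and are not needed if the coordinates are known to be valid pixel positions.

```python
import numpy as np


def assign_transcripts_safely(transcript_coords, edt_labels, outside_id=0):
    """Hypothetical variant of the lookup that tolerates float or out-of-image coordinates."""
    # Round to the nearest pixel and convert to integers, as required for array indexing.
    coords = np.rint(np.asarray(transcript_coords)).astype(np.intp)
    ii = coords[:, 0]
    jj = coords[:, 1]

    # Flag the transcripts that fall inside the image.
    height, width = edt_labels.shape
    inside = (ii >= 0) & (ii < height) & (jj >= 0) & (jj < width)

    # Transcripts outside the image get 'outside_id' (an assumption of this sketch).
    cell_ids = np.full(len(coords), outside_id, dtype=edt_labels.dtype)
    cell_ids[inside] = edt_labels[ii[inside], jj[inside]]
    return cell_ids


demo_labels = np.array([[1, 1, 2], [1, 2, 2]])
demo_coords = np.array([[0.2, 0.0], [1.0, 1.8], [5.0, 1.0]])  # (y, x) pairs
print(assign_transcripts_safely(demo_coords, demo_labels))
```

The third demo transcript lies outside the 2 x 3 image and therefore receives `outside_id` instead of a cell ID.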

Let us now assign the transcripts to the closest nucleus. Thanks to the EDT, this is a single vectorized array lookup for all transcripts at once, so it should take only a fraction of the time needed to calculate the EDT itself.

In [None]:
%%time
cell_ids = assign_transcripts_to_cells(transcript_coords, edt_labels)

Store the cell ID of each transcript in the dataframe.

In [None]:
df["cell_id"] = cell_ids
df

## Visualize assignment of transcripts to cells

First we define a function `crop()` that extracts a region from the labels image together with the transcripts that fall inside it. We will use it to zoom in on a smaller part of the tissue.

In [None]:
def crop(labels_image, transcripts_dataframe, crop_rect):
    toplefty, topleftx, bottomrighty, bottomrightx = tuple(crop_rect)

    # Crop labels image
    img = labels_image[toplefty:bottomrighty, topleftx:bottomrightx]

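    # Keep only transcripts inside crop rectangle.
    # Copy after filtering, so the translation below does not touch the original dataframe.
    df = transcripts_dataframe
    inside = (df.x >= topleftx) & (df.x < bottomrightx) & (df.y >= toplefty) & (df.y < bottomrighty)
    df = df[inside].copy()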

    # Translate transcript coordinates to match cropped labels image coordinates again
    df.x = df.x - topleftx
    df.y = df.y - toplefty

    return img, df

Let's also define a function that can create a matplotlib colormap with random colors. We will use the colormap to assign random colors to cells and their transcripts, so we can visually confirm that transcripts are correctly assigned to nearby cells.

In [None]:
def generate_random_colors(size=10240, seed=0):
    rng = np.random.default_rng(seed)
    colors = rng.random((size, 3))
    colors[0, :] = 1  # force first color to be white
    return colors

We can now check that our assignment of transcripts to nuclei using the Euclidean Distance Transform (EDT) works as expected.

In [None]:
# Crop a closeup region
crop_rect = [8700, 4700, 10400, 6400]
cropped_labels, cropped_df = crop(labels, df, crop_rect)

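# Find labels range for colormap, needed to make imshow() and scatter() use the same colors.
# Take both the nuclei and the transcripts in the crop into account, and convert to a Python int.
vmin = 0
vmax = int(max(np.max(cropped_labels), np.max(cropped_df.cell_id)))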

# Make colormaps, a slightly darker one for the labels so we can still see the transcripts on top of them.
# Pick a large enough colormap size so that nearby cell ID values still get mapped to different colors.
colors = generate_random_colors(size=2 * vmax, seed=1)
labels_cmap = plt.cm.colors.ListedColormap(np.vstack((colors[0, :], colors[1:, :] * 0.8)))
transcripts_cmap = plt.cm.colors.ListedColormap(colors)

# Plot transcripts and nuclei
plt.figure(figsize=(15, 15))
plt.imshow(cropped_labels, cmap=labels_cmap, interpolation="nearest", vmin=vmin, vmax=vmax)
plt.scatter(
    cropped_df.x,
    cropped_df.y,
    c=cropped_df.cell_id,
    s=0.5,
    cmap=transcripts_cmap,
    vmin=vmin,
    vmax=vmax,
)
plt.show()
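Besides visual inspection, a few cheap numerical checks can complement the plot. The sketch below is one possible set of checks, demonstrated on a small synthetic image; the function name `edt_labels_look_consistent` and the synthetic nuclei are assumptions made for this illustration.

```python
import numpy as np
import scipy.ndimage


def edt_labels_look_consistent(labels, edt_labels):
    """Hypothetical sanity checks on the nearest-nucleus labels image."""
    inside_nuclei = labels > 0
    # Pixels inside a nucleus must keep their own label.
    own_label_kept = np.array_equal(edt_labels[inside_nuclei], labels[inside_nuclei])
    # Every pixel must be assigned to some nucleus, never to the background.
    no_background_left = bool(np.all(edt_labels > 0))
    return own_label_kept and no_background_left


# Small synthetic example with two hypothetical nuclei.
check_labels = np.zeros((20, 30), dtype=np.uint16)
check_labels[2:5, 3:6] = 1
check_labels[12:16, 20:25] = 2

check_indices = scipy.ndimage.distance_transform_edt(
    check_labels == 0, return_distances=False, return_indices=True
)
nearest_labels = check_labels[check_indices[0], check_indices[1]]

print(edt_labels_look_consistent(check_labels, nearest_labels))
```

On the data in this notebook the same function could be called as `edt_labels_look_consistent(labels, edt_labels)`.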