# Create UMAP of single cell features [EBSS starvation data]

In [None]:
import glob
from pathlib import Path
import numpy as np
import pandas as pd
import umap
import matplotlib.pyplot as plt

from livecellx.core.datasets import LiveCellImageDataset
from livecellx.sample_data import tutorial_three_image_sys

from livecellx.trajectory.feature_extractors import compute_haralick_features, compute_skimage_regionprops
from livecellx.preprocess.utils import normalize_img_to_uint8
from livecellx.core.parallel import parallelize
from livecellx.core.io_sc import prep_scs_from_mask_dataset
from livecellx.core.single_cell import create_sc_table
import livecellx.core.single_cell
from livecellx.core.single_cell import SingleCellStatic
from livecellx.core.io_utils import LiveCellEncoder


# Locations of the raw images (position XY16) and of the segmentation output
# for the same position.
dataset_dir_path = Path(
    "../datasets/EBSS_Starvation/tif_STAV-A549_VIM_24hours_NoTreat_NA_YL_Ti2e_2022-12-21/XY16/"
)
mask_dataset_path = Path(
    "../datasets/EBSS_Starvation/tif_STAV-A549_VIM_24hours_NoTreat_NA_YL_Ti2e_2022-12-21/out/XY16/seg"
)

# Mask dataset: built from the segmentation directory, PNG files.
mask_dataset = LiveCellImageDataset(mask_dataset_path, ext="png")

# DIC dataset: the "*_DIC.tif" file names are sorted as strings and keyed by
# their position in that order (0, 1, 2, ...).
dic_glob_pattern = str(dataset_dir_path / "*_DIC.tif")
time2url = dict(enumerate(sorted(glob.glob(dic_glob_pattern))))
dic_dataset = LiveCellImageDataset(time2url=time2url, ext="tif")

In [None]:

# Results directory for this notebook (feature table CSV, single-cell JSON);
# created together with any missing parents, and left alone if it exists.
out_dir = Path("tutorial_results/umap_EBSS_STARV")
out_dir.mkdir(parents=True, exist_ok=True)

Compute the features  
If you have already computed the features, skip ahead to "Read the features if computed before" and load them from disk instead.

In [None]:
from livecellx.core.io_sc import prep_scs_from_mask_dataset
# Build the single-cell objects from the mask dataset and the DIC image dataset.
# (presumably one object per segmented cell per frame — confirm against
# prep_scs_from_mask_dataset)
scs = prep_scs_from_mask_dataset(mask_dataset, dic_dataset)

In [None]:
from livecellx.trajectory.feature_extractors import compute_haralick_features, compute_skimage_regionprops
from livecellx.preprocess.utils import normalize_img_to_uint8
from livecellx.core.parallel import parallelize

# One argument dict per cell. The keys mirror the parameter names of
# compute_skimage_regionprops_wrapper (presumably parallelize passes each dict
# as keyword arguments — confirm).
inputs = [
    {
        "sc": sc,
        "preprocess_img_func": normalize_img_to_uint8,
        "sc_level_normalize": True,
    }
    for sc in scs
]

def compute_skimage_regionprops_wrapper(sc, preprocess_img_func=normalize_img_to_uint8, sc_level_normalize=True):
    """Run ``compute_skimage_regionprops`` on ``sc`` and return the result with the cell.

    Args:
        sc: the single cell to process.
        preprocess_img_func: forwarded unchanged to ``compute_skimage_regionprops``.
        sc_level_normalize: forwarded unchanged to ``compute_skimage_regionprops``.

    Returns:
        tuple: ``(features, sc)`` — the value returned by
        ``compute_skimage_regionprops`` followed by the ``sc`` that was passed in.
    """
    features = compute_skimage_regionprops(
        sc,
        preprocess_img_func=preprocess_img_func,
        sc_level_normalize=sc_level_normalize,
    )
    # The cell is handed back with its features so the caller can rebuild the
    # list of cells from the results (presumably because workers operate on
    # copies — confirm).
    return features, sc

# Run the wrapper over all argument dicts with cores=16, then split the
# (features, sc) pairs back into two parallel lists.
outputs = parallelize(compute_skimage_regionprops_wrapper, inputs, cores=16)
features = []
scs = []
for feature_result, processed_sc in outputs:
    features.append(feature_result)
    scs.append(processed_sc)

# # Sequential alternative (without parallelize); collects one result per cell
# features = [
#     compute_skimage_regionprops(sc, preprocess_img_func=normalize_img_to_uint8, sc_level_normalize=True)
#     for sc in scs
# ]

In [None]:
from livecellx.core.single_cell import create_sc_table
import livecellx.core.single_cell
import importlib
importlib.reload(livecellx.core.single_cell)

# Plain feature table; in this cell it is only used for the preview below.
sc_feature_table = create_sc_table(scs)

# Variant requested with time and single-cell id, saved as CSV so the feature
# computation does not have to be repeated.
sc_feature_table_with_time_and_id = create_sc_table(scs, add_time=True, add_sc_id=True)
feature_csv_path = out_dir / "sc_feature_table_with_time_and_id.csv"
sc_feature_table_with_time_and_id.to_csv(feature_csv_path, index=False)

# Preview the first two rows.
sc_feature_table.head(2)

In [None]:
from livecellx.core.single_cell import SingleCellStatic
from livecellx.core.io_utils import LiveCellEncoder
import importlib
importlib.reload(livecellx.core.single_cell)
importlib.reload(livecellx.core.io_utils)

# Save the single cells as JSON next to the feature table. The class is reached
# through the module attribute, so the freshly reloaded module's class is used.
scs_json_path = out_dir / "scs.json"
scs_dataset_dir = out_dir / "dataset"
livecellx.core.single_cell.SingleCellStatic.write_single_cells_json(
    scs,
    scs_json_path,
    dataset_dir=scs_dataset_dir,
)

### Read the features if computed before

In [None]:
import importlib
# Re-execute the two library modules so that edits to their source are picked
# up without restarting the kernel.
# NOTE: importlib.reload does not rebind names obtained earlier with
# `from module import name`; those keep pointing at the pre-reload objects.
importlib.reload(livecellx.core.single_cell)
importlib.reload(livecellx.core.io_utils)

In [None]:
import pandas as pd
# Load what the feature-computation cells wrote into out_dir: the feature
# table (CSV) and the single cells (JSON).
feature_csv_path = out_dir / "sc_feature_table_with_time_and_id.csv"
scs_json_path = out_dir / "scs.json"
sc_feature_table_with_time_and_id = pd.read_csv(feature_csv_path)
scs = SingleCellStatic.load_single_cells_json(scs_json_path)

## Perform UMAP on features

In [None]:
# sort by correlation absolute value
# feature_corr_df = feature_corr_df.reindex(feature_corr_df["corr"].abs().sort_values(ascending=False).index)
# feature_corr_df[:-5]

In [None]:
# Centroid columns; they are dropped from the un-normalized feature table
# before it is embedded.
centroid_features = [
    "skimage_centroid-0",
    "skimage_centroid-1",
    "skimage_centroid_weighted-0",
    "skimage_centroid_weighted-1",
]

### Plotting unnormalized features

In [None]:
# Un-normalized feature table: remove columns that contain no data at all,
# then the centroid columns. (A second all-NaN column drop after this would
# remove nothing, so a single pass is enough.)
unnormalized_img_features = (
    create_sc_table(scs, normalize_features=False)
    .dropna(axis=1, how="all")
    .drop(columns=centroid_features)
)

# Report the columns that still contain at least one missing value.
# NOTE(review): only all-NaN columns are removed above; if this list is not
# empty, the remaining NaNs are passed on to fit_transform — confirm intended.
has_na = unnormalized_img_features.isna().any()
feature_na_cols = unnormalized_img_features.columns[has_na].tolist()
print("feature_na_cols: ", feature_na_cols)

# Embed the cells with UMAP using its default settings.
reducer = umap.UMAP()
unnormalized_embedding = reducer.fit_transform(unnormalized_img_features)


In [None]:

# Two views of the same embedding (first two dimensions), side by side.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
ax_time, ax_area = axes
umap_x = unnormalized_embedding[:, 0]
umap_y = unnormalized_embedding[:, 1]

# Left panel: points colored by each cell's timeframe.
timeframes = [sc.timeframe for sc in scs]
scatter_time = ax_time.scatter(umap_x, umap_y, c=timeframes)
colorbar = fig.colorbar(scatter_time, ax=ax_time, label="Time")

# Right panel: points colored by one feature column.
# NOTE(review): the colorbar is labeled "Area" but the values come from the
# "skimage_centroid_weighted_local-0" column — confirm which one is intended.
area_colors = unnormalized_img_features["skimage_centroid_weighted_local-0"]
scatter_area = ax_area.scatter(umap_x, umap_y, c=area_colors)
colorbar = fig.colorbar(scatter_area, ax=ax_area, label="Area")


In [None]:
import numpy as np
import pandas as pd
# calculate correlation matrix between features and UMAP axis

def compute_dims_corr(reduced_dims: np.ndarray, sc_feature_table: pd.DataFrame) -> pd.DataFrame:
    """Correlate every feature column with every reduced dimension.

    Args:
        reduced_dims: 2-D array; column ``d`` holds the values of reduced
            dimension ``d``, with rows in the same order as ``sc_feature_table``.
        sc_feature_table: table with one column per feature.

    Returns:
        A DataFrame with the columns ``"feature"``, ``"corr"`` and ``"dim"``,
        one row per (feature, dimension) pair, ordered by feature and then by
        dimension. ``corr`` is the value ``DataFrame.corr()`` reports for the
        pair, which leaves out rows where either value is missing. When there
        are no features or no dimensions the result is empty but still has
        the three columns.
    """
    n_dims = reduced_dims.shape[1]
    records = []
    for feature in sc_feature_table.columns:
        feature_values = sc_feature_table[feature]
        for dim in range(n_dims):
            # DataFrame.corr skips NaN pairs; np.corrcoef would instead return
            # NaN as soon as a single feature value is missing.
            pair = pd.DataFrame({"embedding": reduced_dims[:, dim], "feature": feature_values})
            corr = pair.corr().loc["embedding", "feature"]
            records.append({"feature": feature, "corr": corr, "dim": dim})
    # Build the frame once instead of concatenating inside the loop, which
    # copies the accumulated rows on every iteration.
    return pd.DataFrame(records, columns=["feature", "corr", "dim"])

In [None]:
# Sanity check: prints True if the embedding contains at least one NaN.
print(np.isnan(unnormalized_embedding).any())

In [None]:
# Correlation of every un-normalized feature with every embedding dimension
# (one row per feature/dimension pair).
unnormalized_corr_df = compute_dims_corr(unnormalized_embedding, unnormalized_img_features)

In [None]:
# Preview the first three rows of the correlation table.
unnormalized_corr_df[:3]

In [None]:
# Within each dimension, order the rows by |corr|, strongest first.
# An explicit sort on a helper column is used instead of
# groupby("dim").apply(... argsort()[::-1]): rows whose correlation is NaN are
# placed last within their dimension (argsort has reported NaN entries as -1,
# which iloc reads as "last row"), and the "dim" column is kept without relying
# on groupby.apply handing the grouping column to the function.
unnormalized_corr_df = (
    unnormalized_corr_df.assign(_abs_corr=unnormalized_corr_df["corr"].abs())
    .sort_values(["dim", "_abs_corr"], ascending=[True, False])
    .drop(columns="_abs_corr")
    .reset_index(drop=True)
)

# group the sorted table by the dim column
grouped_df = unnormalized_corr_df.groupby("dim")


def get_sorted_values(x):
    """Return the (feature names, correlation values) arrays of one group, in row order."""
    return x["feature"].values, x["corr"].values


# For every dimension, the ranked feature names and the ranked correlations.
# Despite the *_series names, these come back as DataFrames (one row per dim,
# one column per rank) when all groups have the same length.
dim1_series = grouped_df.apply(lambda x: pd.Series(get_sorted_values(x)[0], name="feature"))
dim2_series = grouped_df.apply(lambda x: pd.Series(get_sorted_values(x)[1], name="corr"))
type(dim1_series)

In [None]:
# Inspect the column labels of the per-dimension feature ranking.
dim1_series.columns

### Plotting normalized features

In [None]:
# Normalized feature table with the all-NaN columns removed.
# NOTE(review): unlike the un-normalized table, the centroid columns are not
# dropped here — confirm that the difference is intended.
normalized_img_features = create_sc_table(scs, normalize_features=True).dropna(axis=1, how="all")

# Embed the cells with UMAP using its default settings.
reducer = umap.UMAP()
normalized_embedding = reducer.fit_transform(normalized_img_features)

# Two views of the same embedding (first two dimensions), side by side.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
ax_time, ax_area = axes
umap_x = normalized_embedding[:, 0]
umap_y = normalized_embedding[:, 1]

# Left panel: points colored by each cell's timeframe.
timeframes = [sc.timeframe for sc in scs]
scatter_time = ax_time.scatter(umap_x, umap_y, c=timeframes)
colorbar = fig.colorbar(scatter_time, ax=ax_time, label="Time")

# Right panel: points colored by one feature column.
# NOTE(review): the colorbar is labeled "Area" but the values come from the
# "skimage_centroid_weighted_local-0" column — confirm which one is intended.
area_colors = normalized_img_features["skimage_centroid_weighted_local-0"]
scatter_area = ax_area.scatter(umap_x, umap_y, c=area_colors)
colorbar = fig.colorbar(scatter_area, ax=ax_area, label="Area")


In [None]:
import plotly.graph_objs as go
import plotly.subplots as sp

# Rebuild the normalized feature table (all-NaN columns removed) and fit a new
# UMAP on it for the interactive figure.
# NOTE(review): UMAP() is created without a fixed seed here, so this embedding
# presumably differs from one fitted in another cell — confirm whether a
# previously computed embedding should be reused instead.
normalized_img_features = create_sc_table(scs, normalize_features=True).dropna(axis=1, how="all")
reducer = umap.UMAP()
normalized_embedding = reducer.fit_transform(normalized_img_features)

In [None]:
%matplotlib widget
# Figure with two panels next to each other, titled "Time" and "Area".
fig = sp.make_subplots(rows=1, cols=2, subplot_titles=("Time", "Area"))

# Trace colored by timeframe. Hover text shows the timeframe and the cell id;
# the id is also stored per point in customdata.
time_colorbar = dict(title="Time", x=0.45, yanchor="middle", len=0.5)
time_marker = dict(color=[sc.timeframe for sc in scs], colorbar=time_colorbar)
hover_text = [f"Timeframe: {sc.timeframe}<br>sc id: {sc.id}" for sc in scs]
scatter_time = go.Scatter(
    x=normalized_embedding[:, 0],
    y=normalized_embedding[:, 1],
    mode="markers",
    marker=time_marker,
    text=hover_text,
    customdata=[sc.id for sc in scs],
)

# callback for clicks on a scatter point: looks up the id of the clicked cell
# NOTE(review): nothing in the visible cell registers this callback (e.g. with
# trace.on_click on a FigureWidget) — confirm where it is hooked up.
def on_click(trace, points, state):
    """Print the single-cell id attached to the first clicked point.

    Args:
        trace: the clicked trace; ``trace.customdata`` holds one id per point.
        points: the click selection; ``points.point_inds`` lists the indices
            of the clicked points within ``trace``.
        state: not used.
    """
    print("<debug> points: ", points)
    if points.point_inds:
        # The selection carries point indices, not point objects, so the id is
        # read from the trace's customdata at the first clicked index.
        sc_id = trace.customdata[points.point_inds[0]]
        print("<debug> sc_id: ", sc_id)
        # do something with the single cell object

# Trace for the right-hand panel, colored by one feature column.
# NOTE(review): the colorbar is titled "Area" but the values come from the
# "skimage_centroid_weighted_local-0" column — confirm which one is intended.
scatter_area = go.Scatter(
    x=normalized_embedding[:, 0],
    y=normalized_embedding[:, 1],
    mode="markers",
    marker=dict(color=normalized_img_features["skimage_centroid_weighted_local-0"], colorbar=dict(title="Area", x=1, yanchor="middle", len=0.5)),
)

# Place one trace per panel. add_trace is used for both: append_trace is
# deprecated in plotly in favor of add_trace with row/col.
fig.add_trace(scatter_time, row=1, col=1)
fig.add_trace(scatter_area, row=1, col=2)

fig.update_layout(height=500, width=1000, title_text="UMAP Embedding", clickmode="event")