In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to project code take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import pyrootutils

# Resolve the project root by searching from the current directory for a
# directory that contains `.gitignore`, and configure the session around it.
# This has to run before the `src.*` imports below, which presumably rely on
# the root having been added to the import path (pythonpath=True).
# Because cwd=True, relative paths in later cells are relative to the root.
base_path = pyrootutils.setup_root(
    search_from=".",
    indicator=[".gitignore"],
    project_root_env_var=True,  # set the PROJECT_ROOT environment variable to root directory
    dotenv=True,  # load environment variables from .env if exists in root directory
    pythonpath=True,  # add root directory to the PYTHONPATH (helps with imports)
    cwd=True,  # change current working directory to the root directory (helps with filepaths)
)
import json
import os
import shutil
import sys
from itertools import combinations

import matplotlib
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import squidpy as sq
import wandb

from src.utils.metrics import *
from src.data.utils import *
from src.utils.eval import *
from src.utils.wandb import *

## Define helper functions

## Define paths

In [3]:
# experiment_name = "experiment_mouse_st"
# sc_path = "../data/single-cell/Allenbrain_forSimulation_uniquect.h5ad"
# st_path = "../data/spatial/V1_Mouse_Brain_Sagittal_Anterior.h5ad"

# Folders (under <root>/experiments) holding the original DISSECT outputs.
experiment_names = [
    "experiment_kidney_slideSeq_v2_105",
    "experiment_kidney_slideSeq_v2_UMOD-WT.WT-2a_resolution75",
    "experiment_heart_seqFISH_embryo1_resolution0.11-new",
    # "experiment_heart_seqFISH_embryo1_resolution0.11",
]

# Spatial inputs. The first three entries are paired positionally with
# `experiment_names` and with the keys of `dataset_map` below; the fourth
# (lymph node) has no matching experiment entry.
st_paths = [
    "./data/spatial/simulations_kidney_slideSeq_v2/UMOD-KI.KI-4b_resolution105.h5ad",
    "./data/spatial/simulations_kidney_slideSeq_v2/UMOD-WT.WT-2a_resolution75.h5ad",
    "./data/spatial/simulations_heart_seqFISH/embryo1_resolution0.11.h5ad",
    "./data/spatial/lymph_node/st_lymph.h5ad",
]

# Reference inputs, one per entry of `st_paths`.
# NOTE(review): the first two entries look swapped / inconsistently named
# relative to `st_paths` (WT-2a is listed first here, KI-4b/105 first above)
# -- confirm the intended pairing before using them by index.
sc_paths = [
    "./data/spatial/kidney_slideSeq_v2/UMOD-WT.WT-2a.h5ad",
    "./data/spatial/kidney_slideSeq_v2_105.h5ad",
    "./data/spatial/heart_seqFISH/embryo1.h5ad",
    "./data/single-cell/lymph_node/sc_lymph.h5ad",
]

experiment_paths = [f"{base_path}/experiments/{name}" for name in experiment_names]

# Short dataset label -> name fragment passed to `extract_dataset_name`.
dataset_map = {
    "slideSeq-4b": "UMOD-KI.KI-4b_resolution105",
    "slideSeq-2a": "UMOD-WT.WT-2a_resolution75",
    "seqFISH": "embryo1_resolution0.11",
}
# Short dataset label -> spatial file, paired by position (zip stops at the
# shorter sequence, so the trailing lymph node path is left out).
dataset_path_map = dict(zip(dataset_map, st_paths))

## Visualize data

In [4]:
# Read the last entry of `st_paths` (the lymph node spatial file).
st_data = sc.read_h5ad(st_paths[-1])

  utils.warn_names_duplicates("var")


In [5]:
# Read the last entry of `sc_paths` (the lymph node single-cell file).
sc_data = sc.read_h5ad(sc_paths[-1])

In [6]:
# Show the distinct values of the `Subset` annotation in the reference data.
sc_data.obs["Subset"].unique()

['T_CD4+_TfH', 'T_CD4+_naive', 'T_CD8+_CD161+', 'T_CD4+_TfH_GC', 'DC_CCR7+', ..., 'B_GC_LZ', 'B_GC_DZ', 'B_preGC', 'FDC', 'B_GC_prePB']
Length: 34
Categories (34, object): ['B_Cycling', 'B_GC_DZ', 'B_GC_LZ', 'B_GC_prePB', ..., 'T_TIM3+', 'T_TfR', 'T_Treg', 'VSMC']

In [61]:
# sc_data = sc.read_h5ad(sc_paths[0])
# # convert cell_type column into separate indicator columns
# for cell_type in sc_data.obs["cell_type"].unique():
#     sc_data.obs[cell_type] = (sc_data.obs["cell_type"] == cell_type).astype(int)
# sc.pl.spatial(sc_data, color=sc_data.obs["cell_type"].unique(), show=True)

In [None]:
# load sample names and celltype names
# celltypes = load_celltypes(f"{experiment_path}/datasets/celltypes.txt")
# sample_names = load_sample_names(f"{experiment_path}/datasets/sample_names.txt")

In [None]:
# for experiment_name, st_path, sc_path in zip(experiment_names, st_paths, sc_paths):
#     experiment_path = f"{base_path}/experiments/{experiment_name}"
#     # load original dissect results
#     dissect_results, ensemble_result = load_dissect_results(experiment_path)
#     groundtruth = load_groundtruth(st_path)     

## Load new results

In [7]:
# Collect the finished runs carrying any of the selected tags.
# Tag set used previously: ["hybrid", "transformer", "None"]
tags = ["latestv5"]
runs = [
    run
    for tag in tags
    for run in get_filtered_runs(filter={"tags": {"$in": [tag]}, "state": "finished"})
]
print(f"Loaded {len(runs)} runs")
# Only the runs from position 2 onwards are analysed below.
# NOTE(review): the printed count above still includes the two skipped runs.
runs = runs[2:]

Loaded 5 runs


In [8]:
# Name and tags of every selected run, then one result per run name
# (fetched with a progress bar).
run_names = [r.name for r in runs]
run_tags = [r.tags for r in runs]
new_results = list(map(get_result_for_run_name, tqdm(run_names)))


 33%|███▎      | 1/3 [00:00<00:01,  1.67it/s]

Loaded media/table/predictions-step-5000_519_7b7b9741183900146ecb.table.json


 67%|██████▋   | 2/3 [00:01<00:00,  1.61it/s]

Loaded media/table/predictions-step-5000_519_7d20925beb878bb7c625.table.json


100%|██████████| 3/3 [00:01<00:00,  1.61it/s]

Loaded media/table/predictions-step-5000_519_9d932bf63d6adca30bc3.table.json





In [None]:
# Sanity check: dimensions of the first loaded result.
new_results[0].shape


(1640, 23)

In [9]:
# Pull the hyper-parameters of interest out of each run's config.
run_configs = [run.config for run in runs]
dataset_names = [
    extract_dataset_name(cfg["data/st_path"], dataset_map) for cfg in run_configs
]
learning_rates = [cfg["model/learning_rate"] for cfg in run_configs]
weight_decays = [cfg["model/weight_decay"] for cfg in run_configs]
max_steps = [cfg["trainer/max_steps"] for cfg in run_configs]
# get unique learning rates
unique_learning_rates = list(set(learning_rates))
# later define more sophisticated check
base_names = [get_base_name(tag_list) for tag_list in run_tags]

method_names = get_method_names(base_names, learning_rates, weight_decays, max_steps)
# Overview table: which method label / dataset each run maps to.
pd.DataFrame(
    {
        "method": method_names,
        "dataset": dataset_names,
        "run_name": run_names,
    }
)

Unnamed: 0,method,dataset,run_name
0,gnn-0,seqFISH,royal-haze-297
1,gnn-0,slideSeq-4b,logical-firefly-296
2,gnn-0,slideSeq-2a,feasible-dew-295


## Load original DISSECT results

In [12]:
# dissect_results, ensemble_result = load_dissect_results(experiment_path)
# Keep only the second element returned by `load_dissect_results` (the
# ensemble result) for every experiment folder.
ensemble_results = [load_dissect_results(exp_dir)[1] for exp_dir in experiment_paths]
# Dataset labels come from the spatial paths paired positionally with the experiments.
n_experiments = len(experiment_paths)
dissect_dataset_names = [
    extract_dataset_name(spatial_path, dataset_map)
    for spatial_path in st_paths[:n_experiments]
]
dissect_names = ["DISSECT (ensemble)"] * len(dissect_dataset_names)

In [None]:
# Print the number of rows of each ensemble result, for comparison with the
# row counts of the other methods' results.
for result in ensemble_results:
    print(len(result.index))

1803
3244
1640


## Load GraphST results

In [13]:
# NOTE(review): these are absolute, machine-specific paths; consider moving
# them under the project root like the other inputs.
graph_st_result_paths = [
    "/data/GraphST/Data/slideseq_105_result.h5ad",
    "/data/GraphST/Data/slideseq_75_result.h5ad",
    "/data/GraphST/Data/seqfish_result.h5ad",
]
# Dataset label of each GraphST result file, in the same order as the paths.
graphst_dataset_names = ["slideSeq-4b", "slideSeq-2a", "seqFISH"]
assert len(graph_st_result_paths) == len(graphst_dataset_names), (
    "every GraphST result file needs exactly one dataset label"
)
# One method label per GraphST result. Sized from the GraphST list itself:
# sizing it from the run-derived `dataset_names` only worked while the
# number of selected runs happened to equal the number of GraphST results.
graph_st_names = len(graphst_dataset_names) * ["GraphST"]
graph_st_results = []
for path in graph_st_result_paths:
    result = sc.read_h5ad(path)
    # Keep every `obs` column from the third one onwards.
    # NOTE(review): presumably the first two columns are metadata rather than
    # per-celltype predictions -- confirm against the GraphST output.
    graph_st_results.append(result.obs[result.obs.columns[2::]])
    print(len(result.obs.index))

1803
3244
1640


## Combine results

In [14]:
# Concatenate the three sources in the same order everywhere: new runs,
# DISSECT ensembles, GraphST. The lists are paired by position downstream.
all_dataset_names = dataset_names + dissect_dataset_names + graphst_dataset_names
all_method_names = method_names + dissect_names + graph_st_names
all_results = new_results + ensemble_results + graph_st_results

# Fail early if any source contributed a different number of labels than
# results; a mismatch would otherwise attach results to the wrong
# method/dataset without raising.
assert len(all_dataset_names) == len(all_method_names) == len(all_results), (
    f"misaligned inputs: {len(all_dataset_names)} dataset names, "
    f"{len(all_method_names)} method names, {len(all_results)} results"
)

## Evaluate results

### Compare with ground truth

In [15]:
# For every dataset: load its ground truth, select the results belonging to
# it, and score them twice (samplewise=False and samplewise=True). The
# per-dataset tables are tagged with the dataset label and stacked at the end.
cellwise_frames, samplewise_frames = [], []
for dataset_key in dataset_map:
    groundtruth_path = dataset_path_map[dataset_key]
    print(f"Loading groundtruth for {dataset_key}...")
    groundtruth = load_groundtruth(groundtruth_path)
    print(f"Loaded groundtruth {groundtruth_path}...")

    # Results and method labels are filtered with the same key so they stay aligned.
    filtered_results = filter_data_by_dataset(dataset_key, all_dataset_names, all_results)
    filtered_method_names = filter_data_by_dataset(
        dataset_key, all_dataset_names, all_method_names
    )

    cellwise_results = compare_methods_new(
        filtered_results, groundtruth, methods=filtered_method_names, samplewise=False
    )
    cellwise_results["dataset"] = dataset_key
    cellwise_frames.append(cellwise_results)

    samplewise_results = compare_methods_new(
        filtered_results, groundtruth, methods=filtered_method_names, samplewise=True
    )
    samplewise_results["dataset"] = dataset_key
    samplewise_frames.append(samplewise_results)

all_cellwise_results = pd.concat(cellwise_frames, ignore_index=True)
all_samplewise_results = pd.concat(samplewise_frames, ignore_index=True)

Loading groundtruth for slideSeq-4b...
Loaded groundtruth ./data/spatial/simulations_kidney_slideSeq_v2/UMOD-KI.KI-4b_resolution105.h5ad...
Loading groundtruth for slideSeq-2a...
Loaded groundtruth ./data/spatial/simulations_kidney_slideSeq_v2/UMOD-WT.WT-2a_resolution75.h5ad...
Loading groundtruth for seqFISH...
Loaded groundtruth ./data/spatial/simulations_heart_seqFISH/embryo1_resolution0.11.h5ad...


### Visualize in tabular form

In [16]:
# Aggregate the numeric metric columns per (Method, Fold): mean and standard
# deviation, for the cellwise and the samplewise tables.
group_keys = ["Method", "Fold"]
cellwise_groups = all_cellwise_results.groupby(group_keys)
samplewise_groups = all_samplewise_results.groupby(group_keys)

all_cellwise_grouped = cellwise_groups.mean(numeric_only=True)
all_cellwise_grouped_std = cellwise_groups.std(numeric_only=True)
all_samplewise_grouped = samplewise_groups.mean(numeric_only=True)
all_samplewise_grouped_std = samplewise_groups.std(numeric_only=True)

# Side-by-side tables: cellwise columns first, samplewise columns after.
all_mean_results = pd.concat([all_cellwise_grouped, all_samplewise_grouped], axis=1)
all_std_results = pd.concat(
    [all_cellwise_grouped_std, all_samplewise_grouped_std], axis=1
)

In [17]:
# Display the per-(Method, Fold) means of all metrics.
all_mean_results

Unnamed: 0_level_0,Unnamed: 1_level_0,Correlation,RMSE,CCC,Correlation (samplewise),RMSE (samplewise),CCC (samplewise),JSD
Method,Fold,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
DISSECT (ensemble),0,0.70123,0.091283,0.537234,0.826116,0.094398,0.70511,0.177276
GraphST,0,0.549386,0.089288,0.397738,0.806167,0.090729,0.759979,0.134762
gnn-0,0,0.726947,0.086092,0.631525,0.808678,0.093143,0.75514,0.145164


### Visualize as plot

In [None]:
# Per-celltype comparison of hand-picked methods on a single dataset.
metric = "Correlation"
# NOTE(review): "gnn-3" does not appear in the method table printed earlier
# in this notebook (which lists "gnn-0") -- confirm the label.
relevant_methods = ["DISSECT (ensemble)", "gnn-3"]
dataset = "slideSeq-4b"
methods_mask = all_cellwise_results["Method"].isin(relevant_methods)
data_mask = all_cellwise_results["dataset"] == dataset

fig, axs = plt.subplots(1, 2, figsize=(16, 5), sharey="row", sharex="col")
axs = np.ravel(axs)
left_ax, right_ax = axs
# The left panel is currently only titled; its bar plot is disabled:
# sns.barplot(all_cellwise_results.loc[methods_mask, :], x="Celltype", y=metric, hue="Method", ax=axs[0])
left_ax.set_title("Performance per celltype for selected methods")
sns.barplot(
    data=all_cellwise_results.loc[methods_mask & data_mask, :],
    x="Celltype",
    y=metric,
    hue="Method",
    ax=right_ax,
)
# Both panels get slanted tick labels and a grid.
for ax in axs:
    ax.set_xticklabels(ax.get_xticklabels(), rotation=30, ha="right")
    ax.grid(True)
plt.show()

In [None]:
# Left: per-celltype scores of hand-picked methods (all datasets pooled).
# Right: every method's score, split by dataset.
metric = "Correlation"
relevant_methods = ["DISSECT (ensemble)", "gnn-3"]
methods_mask = all_cellwise_results["Method"].isin(relevant_methods)

fig, axs = plt.subplots(1, 2, figsize=(16, 5), sharey="row", sharex="col")
axs = np.ravel(axs)
left_ax, right_ax = axs
sns.barplot(
    data=all_cellwise_results.loc[methods_mask, :],
    x="Celltype",
    y=metric,
    hue="Method",
    ax=left_ax,
)
left_ax.set_title("Performance per celltype for selected methods")
sns.barplot(data=all_cellwise_results, x="Method", y=metric, hue="dataset", ax=right_ax)
# Both panels get slanted tick labels and a grid.
for ax in axs:
    ax.set_xticklabels(ax.get_xticklabels(), rotation=30, ha="right")
    ax.grid(True)
plt.show()

### Single dataset analysis

In [None]:
# Evaluate a single dataset. Previously this cell scored *all* results
# (three datasets with different numbers of spots) against whatever
# `groundtruth` was left over from the last iteration of the evaluation
# loop. The dataset is now chosen explicitly, its ground truth is loaded
# here, and only the matching results/method labels are compared.
single_dataset = "seqFISH"  # any key of `dataset_map`
groundtruth = load_groundtruth(dataset_path_map[single_dataset])
single_results = filter_data_by_dataset(single_dataset, all_dataset_names, all_results)
single_method_names = filter_data_by_dataset(
    single_dataset, all_dataset_names, all_method_names
)

cellwise_results = compare_methods_new(
    single_results, groundtruth, methods=single_method_names, samplewise=False
)
samplewise_results = compare_methods_new(
    single_results, groundtruth, methods=single_method_names, samplewise=True
)

In [None]:
# Mean of every numeric metric per (Method, Fold), cellwise and samplewise
# columns shown side by side.
group_keys = ["Method", "Fold"]
cellwise_grouped = cellwise_results.groupby(group_keys).mean(numeric_only=True)
samplewise_grouped = samplewise_results.groupby(group_keys).mean(numeric_only=True)
pd.concat([cellwise_grouped, samplewise_grouped], axis=1)

In [None]:
# plot performance per celltype
# `experiment_name` only tags the output file. It was never assigned in a
# live cell (only in commented-out code), so the savefig below raised a
# NameError on a fresh kernel.
# NOTE(review): the last experiment is used because it corresponds to the
# last dataset of the evaluation loop -- adjust if another dataset is analysed.
experiment_name = experiment_names[-1]

# Per-(Method, Fold) means; identical for both metrics, so computed once.
mean_cellwise_results = cellwise_results.groupby(
    ["Method", "Fold"], as_index=False
).mean(numeric_only=True)

fig, axs = plt.subplots(2, 2, figsize=(16, 10), sharey="row", sharex="col")
axs = np.ravel(axs)
# One row per metric: per-celltype bars on the left, per-method means on the right.
for row, metric_name in enumerate(["Correlation", "RMSE"]):
    sns.barplot(
        cellwise_results, x="Celltype", y=metric_name, hue="Method", ax=axs[2 * row]
    )
    sns.barplot(mean_cellwise_results, x="Method", y=metric_name, ax=axs[2 * row + 1])
# rotate xticks (bottom row only; the columns share their x axis)
for ax in axs[-2:]:
    ax.set_xticklabels(ax.get_xticklabels(), rotation=30, ha="right")
for ax in axs:
    ax.grid(True)

# Make sure the output directory exists before saving.
os.makedirs("./figures", exist_ok=True)
plt.savefig(
    f"./figures/evaluation_{experiment_name}_cellwise.png", dpi=200, bbox_inches="tight"
)
plt.show()

In [None]:
# samplewise results
# For each metric: distribution per method (box plot, split by fold) next to
# the per-(Method, Fold) mean (bar plot).
metrics = ["Correlation", "RMSE"]
# The grouped means do not depend on the metric, so compute them once.
# numeric_only=True matches every other groupby in this notebook and keeps
# non-numeric columns from breaking the aggregation on recent pandas.
mean_samplewise_results = samplewise_results.groupby(
    ["Method", "Fold"], as_index=False
).mean(numeric_only=True)
for metric in metrics:
    fig, axs = plt.subplots(1, 2, figsize=(16, 5), sharey="row", sharex="col")
    sns.boxplot(
        samplewise_results,
        x="Method",
        y=metric + " (samplewise)",
        hue="Fold",
        ax=axs[0],
    )
    sns.barplot(
        mean_samplewise_results, x="Method", y=metric + " (samplewise)", ax=axs[1]
    )
    # plt.savefig(
    #     f"./figures/evaluation_{experiment_name}_cellwise.png", dpi=200, bbox_inches="tight"
    # )
    plt.show()

## Qualitative analysis

In [None]:
# load data
st_path = "./data/spatial/V1_Mouse_Brain_Sagittal_Anterior.h5ad"
st_data = sc.read_h5ad(st_path)

# Runs inspected previously (kept for reference):
# run_name = "super-armadillo-244"
# run_name = "captain-maquis-285"
# run_name = "defiant-frontier-290"
# run_name = "xindi-q-292"
run_name = "robust-sun-298"
result = get_result_for_run_name(run_name, project="dissect-spatial")

# Append the result columns to the spot annotations so they can be used as
# plot colours. NOTE(review): relies on `result` sharing the index of
# `st_data.obs` -- confirm.
combined_obs = pd.concat([st_data.obs, result], axis=1)
st_data.obs = combined_obs

In [None]:
# for col in result.columns:
#     fig, axs = plt.subplots(1, 1, figsize=(10, 5))
# sc.pl.spatial(st_data, color=result.columns, show=True, save=f"mouse_brain_new_v2.png", vmin=0.0, vmax=1.0)
# One spatial panel per result column, all on a fixed 0..1 colour scale.
# The figure is saved manually below, so scanpy's own show/save are off.
sc.pl.spatial(
    st_data,
    color=result.columns,
    vmin=0.0,
    vmax=1.0,
    show=False,
    save=False,
)
plt.savefig("./figures/mouse.png", dpi=300, bbox_inches="tight")
plt.show()

In [None]:
# load data
st_path = "./data/spatial/V1_Mouse_Brain_Sagittal_Anterior.h5ad"
st_data = sc.read_h5ad(st_path)

# Second element returned by `load_dissect_results` is the ensemble result.
dissect_result = load_dissect_results("experiments/experiment_mouse_st")[1]
st_data.obs = pd.concat([st_data.obs, dissect_result], axis=1)
# Colour by the columns of the DISSECT result loaded in this cell. The
# previous version used `result.columns`, a variable from the preceding
# cell, so this cell failed on its own and could plot the wrong columns.
fig = sc.pl.spatial(
    st_data, color=dissect_result.columns, show=False, save=False, vmin=0.0, vmax=1.0
)
plt.savefig("./figures/mouse_brain_dissect_ensemble.png", dpi=300, bbox_inches="tight")
plt.show()

## Old analysis

### Predicted max celltype per spot

In [None]:
# For every seeded result, label each row with the column holding its
# largest value and plot that label over the tissue.
# NOTE(review): `results_per_seed` is not assigned in any visible cell of
# this notebook -- it must come from an earlier/removed cell.
for seed_idx, seed_result in enumerate(results_per_seed):
    label_column = f"celltype-{seed_idx}"
    st_data.obs[label_column] = seed_result.idxmax(axis=1)
    sq.pl.spatial_scatter(st_data, color=label_column)

### Celltype distribution across tissue

In [None]:
# Copy the `celltypes` columns of `new_result` into the spot annotations,
# plot the "L5 IT" column over the tissue and save the figure.
# NOTE(review): `celltypes` and `new_result` are not assigned in any visible
# live cell of this notebook (only commented-out code mentions `celltypes`),
# and `experiment_name` must be set by an earlier cell -- confirm before running.
st_data.obs[celltypes] = new_result[celltypes]
sq.pl.spatial_scatter(st_data, color="L5 IT")
plt.savefig(
    f"{base_path}/figures/single_celltype_{experiment_name}.png",
    dpi=200,
    bbox_inches="tight",
)

In [None]:
# Overwrite the `celltypes` columns of the spot annotations with the values
# from `ensemble_result` and plot one panel per entry of `celltypes`.
# NOTE(review): `celltypes` and `ensemble_result` are not assigned in any
# visible live cell of this notebook -- confirm where they come from.
st_data.obs[celltypes] = ensemble_result[celltypes]
sq.pl.spatial_scatter(st_data, color=celltypes)

### Comparison between dissect runs for different seeds

In [None]:
# Pairwise comparison of all seeded results: entry (i, j) is the first value
# returned by `calc_mean_corr_df` for seed i vs seed j.
# The matrix is sized from the number of seeded results instead of a
# hard-coded 5, so it keeps working when the number of seeds changes.
n_seeds = len(results_per_seed)
all_corrs = [
    calc_mean_corr_df(df_1, df_2, verbose=1)[0]
    for df_1 in results_per_seed
    for df_2 in results_per_seed
]
all_corrs = np.reshape(all_corrs, (n_seeds, n_seeds))

In [None]:
# One axis label per row of the correlation matrix (previously fixed at 5).
labels = [f"Seed {i}" for i in range(len(all_corrs))]

In [None]:
def heatmap(
    data, row_labels, col_labels, ax=None, cbar_kw=None, cbarlabel="", **kwargs
):
    """
    Draw a labelled heatmap of a 2D array together with a colorbar.

    Parameters
    ----------
    data
        2D numpy array of shape (M, N) to display.
    row_labels
        Sequence of M labels, one per row.
    col_labels
        Sequence of N labels, one per column.
    ax
        Axes to draw into.  Defaults to the current axes (``plt.gca()``).
    cbar_kw
        Optional dict of keyword arguments passed to the figure's
        ``colorbar`` call.
    cbarlabel
        Text placed along the colorbar.
    **kwargs
        Passed through unchanged to ``imshow``.

    Returns
    -------
    tuple
        ``(im, cbar)``: the image returned by ``imshow`` and the colorbar.
    """
    target_ax = plt.gca() if ax is None else ax
    colorbar_options = {} if cbar_kw is None else cbar_kw

    # Image first, then its colorbar with a label rotated to read top-down.
    image = target_ax.imshow(data, **kwargs)
    colorbar = target_ax.figure.colorbar(image, ax=target_ax, **colorbar_options)
    colorbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")

    n_rows = data.shape[0]
    n_cols = data.shape[1]

    # One major tick per cell, named after the supplied labels.
    target_ax.set_xticks(np.arange(n_cols), labels=col_labels)
    target_ax.set_yticks(np.arange(n_rows), labels=row_labels)

    # Column labels go above the image, slanted and anchored at their end.
    target_ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
    plt.setp(
        target_ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor"
    )

    # Hide the frame; a white grid on the minor ticks (placed on the cell
    # borders) separates the cells instead.
    target_ax.spines[:].set_visible(False)
    target_ax.set_xticks(np.arange(n_cols + 1) - 0.5, minor=True)
    target_ax.set_yticks(np.arange(n_rows + 1) - 0.5, minor=True)
    target_ax.grid(which="minor", color="w", linestyle="-", linewidth=3)
    target_ax.tick_params(which="minor", bottom=False, left=False)

    return image, colorbar


def annotate_heatmap(
    im,
    data=None,
    valfmt="{x:.2f}",
    textcolors=("black", "white"),
    threshold=None,
    **textkw
):
    """
    Write the value of every heatmap cell as text on top of the image.

    Parameters
    ----------
    im
        The AxesImage to be labeled.
    data
        2D data used to annotate, as a numpy array or a (nested) list.
        If None (or any other type), the image's own data is used.  Optional.
    valfmt
        The format of the annotations inside the heatmap.  This should either
        use the string format method, e.g. "$ {x:.2f}", or be a
        `matplotlib.ticker.Formatter` (any callable taking ``(value, pos)``).
        Optional.
    textcolors
        A pair of colors.  The first is used for values below a threshold,
        the second for those above.  Optional.
    threshold
        Value in data units according to which the colors from textcolors are
        applied.  If None (the default) uses half of the normalized maximum
        of the data as separation.  Optional.
    **textkw
        All other arguments are forwarded to each call to `text` used to create
        the text labels.

    Returns
    -------
    list
        The created text objects in row-major order; empty if the data has
        no elements.
    """

    if isinstance(data, list):
        # Lists are accepted, but the code below needs array semantics
        # (.max(), .shape, 2D indexing), which a plain list does not offer.
        data = np.asarray(data)
    elif not isinstance(data, np.ndarray):
        data = im.get_array()

    # Nothing to annotate; also avoids calling .max() on an empty array.
    if data.size == 0:
        return []

    # Normalize the threshold to the images color range.
    if threshold is not None:
        threshold = im.norm(threshold)
    else:
        threshold = im.norm(data.max()) / 2.0

    # Set default alignment to center, but allow it to be
    # overwritten by textkw.
    kw = dict(horizontalalignment="center", verticalalignment="center")
    kw.update(textkw)

    # Get the formatter in case a string is supplied
    if isinstance(valfmt, str):
        valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)

    # Loop over the data and create a `Text` for each "pixel".
    # Change the text's color depending on the data.
    texts = []
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            # Index 0 (first color) at or below the threshold, 1 above it.
            kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
            text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)
            texts.append(text)

    return texts

In [None]:
# Annotated heatmap of the pairwise correlations between seeded runs.
fig, ax = plt.subplots()

im, cbar = heatmap(
    all_corrs, labels, labels, ax=ax, cmap="Blues", cbarlabel="Pearson correlation"
)
texts = annotate_heatmap(im, valfmt="{x:.2f}")
ax.set_title("Correlation between seeded runs in original DISSECT")

fig.tight_layout()
# The working directory is the project root (set by pyrootutils with
# cwd=True), so figures belong in ./figures like in the other cells;
# "../figures" pointed outside the project.
plt.savefig(
    "./figures/corr_heatmap_dissect_original.png", dpi=200, bbox_inches="tight"
)
plt.show()

### Comparison with new implementation

In [None]:
# First value returned by `calc_mean_corr_df` for each seeded result vs the
# new result, as a single column. The number of rows is inferred (-1)
# instead of being hard-coded to 5 seeds.
new_corrs = np.reshape(
    [calc_mean_corr_df(df, new_result)[0] for df in results_per_seed], (-1, 1)
)

In [None]:
# Distribution of the per-seed correlations with the new implementation.
fig, ax = plt.subplots()
ax.boxplot(new_corrs)
# Single box, so blank out its tick label.
ax.set_xticklabels([" "])
ax.set_title("Mean correlation between seeded runs and new DISSECT")
fig.tight_layout()
# The working directory is the project root (set by pyrootutils with
# cwd=True), so figures belong in ./figures like in the other cells;
# "../figures" pointed outside the project.
plt.savefig("./figures/corr_dissect_new_vs_old.png", dpi=200, bbox_inches="tight")
plt.show()

In [None]:
# Summary statistics of the per-seed correlations: mean, std, max, min.
for summarize in (np.mean, np.std, np.max, np.min):
    print(summarize(new_corrs))

## Compare per cell type

In [None]:
# Unpack the eight values returned by `calc_metrics_df` for the ensemble
# result vs `y_real`; judging by the target names these are mean and
# individual correlations / RMSEs, each in a regular and a "sample" variant.
# NOTE(review): `ensemble_result` and `y_real` are not assigned in any
# visible live cell of this notebook -- confirm where they come from.
(
    mean_corr,
    corrs,
    mean_corr_sample,
    corrs_sample,
    mean_rmse,
    rmses,
    mean_rmse_sample,
    rmses_sample,
) = calc_metrics_df(ensemble_result, y_real, verbose=1, exclude_cols=None)

In [None]:
# plot corrs per cell type
# One bar per column of `y_real`, with `corrs` as bar heights.
# NOTE(review): assumes `corrs` is ordered like `y_real.columns` -- confirm
# against `calc_metrics_df`.
fig, ax = plt.subplots()
cell_types = y_real.columns
ax.bar(cell_types, corrs)

In [None]:
# Store every ensemble column in the spot annotations under "<name> pred",
# leaving any existing column called "<name>" untouched.
for col in ensemble_result.columns:
    st_data.obs[f"{col} pred"] = ensemble_result[col]

In [None]:
# For every ensemble column, draw two spatial maps side by side: the
# annotation named `<col>` on the left and `<col> pred` on the right.
for column_name in ensemble_result.columns:
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    left_ax, right_ax = axs[0], axs[1]
    # show=False keeps the figure open so the second panel can be added.
    sc.pl.spatial(st_data, color=column_name, ax=left_ax, show=False, cmap="Reds")
    sc.pl.spatial(st_data, color=f"{column_name} pred", ax=right_ax, cmap="Reds")

In [None]:
# Store every ensemble column in the spot annotations under "<name> pred".
# NOTE(review): this repeats, statement for statement, the cell two cells
# above; it is only needed if `st_data` was reloaded in between.
for col in ensemble_result.columns:
    st_data.obs[f"{col} pred"] = ensemble_result[col]