# CellDART Example Code: human DLPFC (spatialLIBD)
## (10x Visium slides of the human dorsolateral prefrontal cortex + scRNA-seq reference data)

In [1]:
import glob
import os
import pickle
from math import ceil

import anndata as ad
import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import scanpy as sc
import seaborn as sns
import scipy.stats as ss

from sklearn.model_selection import train_test_split
from sklearn.metrics import RocCurveDisplay
from sklearn.metrics import mean_squared_error
from sklearn import metrics

import tensorflow as tf  # TensorFlow registers PluggableDevices here.

from CellDART import da_cellfraction
from CellDART.utils import random_mix

from src.utils.data_loading import load_spatial, load_sc, get_selected_dir

from tqdm.autonotebook import tqdm


  from tqdm.autonotebook import tqdm


In [2]:
# Sanity check: show which physical devices (CPU / GPU) TensorFlow can see.
visible_devices = tf.config.list_physical_devices()
visible_devices


[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
 PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [4]:
# ---- Data selection (passed to get_selected_dir / load_spatial / load_sc) ----
DATA_DIR = "../AGrEDA/data"
# DATA_DIR = "./data_combine/"
N_MARKERS = 20
ALL_GENES = False
SCALER_NAME = "celldart"

# One model adapted to all ST slides at once (True) vs. one model per slide (False).
TRAIN_USING_ALL_ST_SAMPLES = False
ST_SPLIT = False

# ---- Single-cell mixtures (n_spots / n_mix arguments of load_sc) ----
N_SPOTS = 20000
N_MIX = 8

# Slide id used by the (currently commented-out) evaluation and plotting cells.
SAMPLE_ID_N = "151673"

# ---- Training hyper-parameters for da_cellfraction.train ----
INITIAL_TRAIN_EPOCHS = 10
BATCH_SIZE = 512
ALPHA = 0.6
ALPHA_LR = 5
N_ITER = 3000

# ---- Bootstrap re-training (alternative branch of the training cell) ----
BOOTSTRAP = False
BOOTSTRAP_ROUNDS = 10
BOOTSTRAP_ALPHAS = [0.6, 1 / 0.6]

# ---- Output naming: models are saved under model/<MODEL_NAME>/<MODEL_VERSION> ----
MODEL_NAME = "CellDART"
MODEL_VERSION = "V1"


In [5]:
# Checkpoints for this model/version are stored under model/<MODEL_NAME>/<MODEL_VERSION>.
model_folder = os.path.join(
    "model",
    MODEL_NAME,
    MODEL_VERSION,
)

# Create the folder only on the first run, echoing the freshly created path.
needs_creation = not os.path.isdir(model_folder)
if needs_creation:
    os.makedirs(model_folder)
    print(model_folder)


## 1. Data loading
### Load the spatial transcriptomics (ST) and single-cell (scRNA-seq) datasets

In [6]:
# Both loaders read from the same preprocessed-data directory (chosen by
# DATA_DIR, N_MARKERS and ALL_GENES), so resolve it once instead of twice.
selected_data_dir = get_selected_dir(DATA_DIR, N_MARKERS, ALL_GENES)

# Load spatial data: per-slide matrices, the training matrix used when all
# slides are pooled, and the list of ST sample ids.
mat_sp_d, mat_sp_train, st_sample_id_l = load_spatial(
    selected_data_dir,
    SCALER_NAME,
    train_using_all_st_samples=TRAIN_USING_ALL_ST_SAMPLES,
    st_split=ST_SPLIT,
)

# Load sc data: mixture inputs and their labels (split dicts, e.g. "train"),
# plus lookup dicts; the plotting cells use sc_sub_dict[idx] as a cell-type name.
sc_mix_d, lab_mix_d, sc_sub_dict, sc_sub_dict2 = load_sc(
    selected_data_dir,
    SCALER_NAME,
    n_mix=N_MIX,
    n_spots=N_SPOTS,
)


In [7]:
# Adversarially trained models and pretrained (no domain adaptation) models
# are stored side by side under the model folder.
advtrain_folder = os.path.join(model_folder, "advtrain")
pretrain_folder = os.path.join(model_folder, "pretrain")

# exist_ok=True skips folders that already exist but still raises if the path
# is occupied by a non-directory, matching the previous isdir() guard.
for new_dir in (advtrain_folder, pretrain_folder):
    os.makedirs(new_dir, exist_ok=True)


In [8]:
# st_sample_id_l = [SAMPLE_ID_N]


In [9]:
# Keyword arguments shared by every da_cellfraction.train call below. The
# configuration constants are used instead of re-typing their literal values
# (previously alpha_lr=5 / initial_train_epochs=10 were hard-coded).
train_kwargs = dict(
    alpha_lr=ALPHA_LR,
    emb_dim=64,
    batch_size=BATCH_SIZE,
    n_iterations=N_ITER,
    initial_train=True,
    initial_train_epochs=INITIAL_TRAIN_EPOCHS,
)

if TRAIN_USING_ALL_ST_SAMPLES:
    # A single model adapted to all ST slides at once.
    # NOTE(review): assumes `mat_sp_train` (returned by load_spatial with
    # train_using_all_st_samples=True) is the pooled ST training matrix; the
    # previous code referenced an undefined name `mat_sp_train_s`.
    print("Adversarial training for all ST slides")
    # train returns four objects (see the per-slide branch below); unpacking
    # three raised ValueError.
    embs, embs_noda, clssmodel, clssmodel_noda = da_cellfraction.train(
        sc_mix_d["train"],
        lab_mix_d["train"],
        mat_sp_train,
        alpha=ALPHA,
        **train_kwargs,
    )
elif BOOTSTRAP:
    # pred_sp_boostrap_d[alpha][sample_id] holds BOOTSTRAP_ROUNDS prediction
    # arrays, one per re-trained model (each round seeded by its index).
    pred_sp_boostrap_d = {}

    outer = tqdm(total=len(BOOTSTRAP_ALPHAS), desc="Alphas", position=0)
    inner1 = tqdm(total=len(st_sample_id_l), desc="Sample", position=1)
    inner2 = tqdm(total=BOOTSTRAP_ROUNDS, desc="Bootstrap #", position=2)
    try:
        for alpha in BOOTSTRAP_ALPHAS:
            inner1.refresh()  # force print final state
            inner1.reset()  # reuse bar

            pred_sp_boostrap_d[alpha] = {}

            for sample_id in st_sample_id_l:
                inner2.refresh()  # force print final state
                inner2.reset()  # reuse bar

                # Same [sample_id]["train"] indexing as the per-slide branch;
                # the old ["train"][sample_id] order did not match load_spatial's output.
                mat_sp_sample = mat_sp_d[sample_id]["train"]

                pred_sp_boostrap_d[alpha][sample_id] = []
                for i in range(BOOTSTRAP_ROUNDS):
                    print(f"Adversarial training for ST slide {sample_id}: ")
                    embs, embs_noda, clssmodel, clssmodel_noda = da_cellfraction.train(
                        sc_mix_d["train"],
                        lab_mix_d["train"],
                        mat_sp_sample,
                        alpha=alpha,
                        seed=i,
                        **train_kwargs,
                    )

                    pred_sp_boostrap_d[alpha][sample_id].append(
                        clssmodel.predict(mat_sp_sample)
                    )
                    inner2.update(1)
                inner1.update(1)
            outer.update(1)
    finally:
        # Release the progress bars even if training fails part-way.
        for bar in (inner2, inner1, outer):
            bar.close()

else:
    # One model per ST slide, each saved under <folder>/<sample_id>/.
    for sample_id in st_sample_id_l:
        print(f"Adversarial training for ST slide {sample_id}: ")
        sample_advtrain_dir = os.path.join(advtrain_folder, sample_id)
        sample_pretrain_dir = os.path.join(pretrain_folder, sample_id)
        os.makedirs(sample_advtrain_dir, exist_ok=True)
        os.makedirs(sample_pretrain_dir, exist_ok=True)

        embs, embs_noda, clssmodel, clssmodel_noda = da_cellfraction.train(
            sc_mix_d["train"],
            lab_mix_d["train"],
            mat_sp_d[sample_id]["train"],
            alpha=ALPHA,
            seed=int(sample_id),
            **train_kwargs,
        )

        # *_noda objects (no domain adaptation) go to the pretrain folder; the
        # adversarially adapted ones go to the advtrain folder.
        clssmodel_noda.save(os.path.join(sample_pretrain_dir, "final_model"))
        clssmodel.save(os.path.join(sample_advtrain_dir, "final_model"))

        embs_noda.save(os.path.join(sample_pretrain_dir, "embs"))
        embs.save(os.path.join(sample_advtrain_dir, "embs"))


Adversarial training for ST slide 151507: 
Train on 20000 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
initial_train_done


  updates = self.state_updates


0.609762088394165
0.609762088394165
Iteration 99, source loss =  3.951, discriminator acc = 0.180
Iteration 199, source loss =  1.882, discriminator acc = 0.463
Iteration 299, source loss =  1.421, discriminator acc = 0.930
Iteration 399, source loss =  1.565, discriminator acc = 0.834
Iteration 499, source loss =  1.366, discriminator acc = 0.988
Iteration 599, source loss =  1.210, discriminator acc = 0.156
Iteration 699, source loss =  1.207, discriminator acc = 0.998
Iteration 799, source loss =  1.071, discriminator acc = 0.741
Iteration 899, source loss =  1.119, discriminator acc = 0.907
Iteration 999, source loss =  1.090, discriminator acc = 0.963
Iteration 1099, source loss =  1.032, discriminator acc = 0.190
Iteration 1199, source loss =  1.020, discriminator acc = 0.199
Iteration 1299, source loss =  1.007, discriminator acc = 0.172
Iteration 1399, source loss =  0.902, discriminator acc = 0.939
Iteration 1499, source loss =  0.985, discriminator acc = 0.796
Iteration 1599,

  updates = self.state_updates


0.6339802798271179
0.6339802798271179
Iteration 99, source loss =  4.053, discriminator acc = 0.067
Iteration 199, source loss =  2.089, discriminator acc = 0.856
Iteration 299, source loss =  2.219, discriminator acc = 0.821
Iteration 399, source loss =  1.455, discriminator acc = 0.856
Iteration 499, source loss =  1.598, discriminator acc = 0.843
Iteration 599, source loss =  1.428, discriminator acc = 0.835
Iteration 699, source loss =  1.336, discriminator acc = 0.836
Iteration 799, source loss =  1.223, discriminator acc = 0.880
Iteration 899, source loss =  1.190, discriminator acc = 0.956
Iteration 999, source loss =  1.423, discriminator acc = 0.968
Iteration 1099, source loss =  1.280, discriminator acc = 0.944
Iteration 1199, source loss =  1.148, discriminator acc = 0.917
Iteration 1299, source loss =  1.064, discriminator acc = 0.414
Iteration 1399, source loss =  1.040, discriminator acc = 0.728
Iteration 1499, source loss =  0.908, discriminator acc = 0.859
Iteration 159

  updates = self.state_updates


0.6195952402114868
0.6195952402114868
Iteration 99, source loss =  3.593, discriminator acc = 0.056
Iteration 199, source loss =  1.802, discriminator acc = 0.818
Iteration 299, source loss =  1.439, discriminator acc = 0.812
Iteration 399, source loss =  1.289, discriminator acc = 0.471
Iteration 499, source loss =  1.645, discriminator acc = 0.187
Iteration 599, source loss =  1.265, discriminator acc = 0.195
Iteration 699, source loss =  1.326, discriminator acc = 0.788
Iteration 799, source loss =  1.313, discriminator acc = 0.195
Iteration 899, source loss =  1.262, discriminator acc = 0.931
Iteration 999, source loss =  1.183, discriminator acc = 0.488
Iteration 1099, source loss =  1.184, discriminator acc = 0.871
Iteration 1199, source loss =  1.133, discriminator acc = 0.706
Iteration 1299, source loss =  1.137, discriminator acc = 0.999
Iteration 1399, source loss =  1.082, discriminator acc = 0.883
Iteration 1499, source loss =  0.995, discriminator acc = 0.997
Iteration 159

  updates = self.state_updates


0.6635840937614441
0.6635840937614441
Iteration 99, source loss =  3.144, discriminator acc = 0.698
Iteration 199, source loss =  1.824, discriminator acc = 0.876
Iteration 299, source loss =  1.579, discriminator acc = 0.824
Iteration 399, source loss =  1.781, discriminator acc = 0.281
Iteration 499, source loss =  1.429, discriminator acc = 0.992
Iteration 599, source loss =  1.299, discriminator acc = 0.871
Iteration 699, source loss =  1.148, discriminator acc = 0.792
Iteration 799, source loss =  1.150, discriminator acc = 0.959
Iteration 899, source loss =  0.954, discriminator acc = 0.896
Iteration 999, source loss =  0.886, discriminator acc = 0.277
Iteration 1099, source loss =  0.805, discriminator acc = 0.347
Iteration 1199, source loss =  0.775, discriminator acc = 0.733
Iteration 1299, source loss =  0.718, discriminator acc = 0.833
Iteration 1399, source loss =  0.786, discriminator acc = 0.907
Iteration 1499, source loss =  0.761, discriminator acc = 0.473
Iteration 159

  updates = self.state_updates


0.6345101404190063
0.6345101404190063
Iteration 99, source loss =  3.426, discriminator acc = 0.155
Iteration 199, source loss =  2.907, discriminator acc = 0.714
Iteration 299, source loss =  2.118, discriminator acc = 0.993
Iteration 399, source loss =  1.660, discriminator acc = 0.208
Iteration 499, source loss =  1.162, discriminator acc = 0.879
Iteration 599, source loss =  1.209, discriminator acc = 0.979
Iteration 699, source loss =  1.437, discriminator acc = 0.180
Iteration 799, source loss =  1.228, discriminator acc = 0.939
Iteration 899, source loss =  1.207, discriminator acc = 0.930
Iteration 999, source loss =  1.056, discriminator acc = 0.615
Iteration 1099, source loss =  0.919, discriminator acc = 0.976
Iteration 1199, source loss =  0.989, discriminator acc = 0.182
Iteration 1299, source loss =  0.841, discriminator acc = 0.892
Iteration 1399, source loss =  0.768, discriminator acc = 0.964
Iteration 1499, source loss =  0.784, discriminator acc = 0.181
Iteration 159

  updates = self.state_updates


0.6091950691223145
0.6091950691223145
Iteration 99, source loss =  3.522, discriminator acc = 0.226
Iteration 199, source loss =  2.607, discriminator acc = 0.594
Iteration 299, source loss =  1.480, discriminator acc = 0.145
Iteration 399, source loss =  1.309, discriminator acc = 0.182
Iteration 499, source loss =  1.486, discriminator acc = 0.149
Iteration 599, source loss =  1.137, discriminator acc = 0.311
Iteration 699, source loss =  1.404, discriminator acc = 0.476
Iteration 799, source loss =  1.285, discriminator acc = 0.149
Iteration 899, source loss =  1.105, discriminator acc = 0.515
Iteration 999, source loss =  1.093, discriminator acc = 0.328
Iteration 1099, source loss =  1.071, discriminator acc = 0.149
Iteration 1199, source loss =  1.082, discriminator acc = 0.785
Iteration 1299, source loss =  1.129, discriminator acc = 0.902
Iteration 1399, source loss =  1.205, discriminator acc = 0.974
Iteration 1499, source loss =  0.989, discriminator acc = 0.216
Iteration 159

  updates = self.state_updates


0.5910807642936706
0.5910807642936706
Iteration 99, source loss =  2.810, discriminator acc = 0.170
Iteration 199, source loss =  2.593, discriminator acc = 0.170
Iteration 299, source loss =  1.458, discriminator acc = 0.869
Iteration 399, source loss =  1.480, discriminator acc = 0.999
Iteration 499, source loss =  1.322, discriminator acc = 0.170
Iteration 599, source loss =  1.472, discriminator acc = 0.504
Iteration 699, source loss =  1.262, discriminator acc = 0.161
Iteration 799, source loss =  1.151, discriminator acc = 0.840
Iteration 899, source loss =  1.181, discriminator acc = 0.233
Iteration 999, source loss =  1.170, discriminator acc = 0.998
Iteration 1099, source loss =  1.086, discriminator acc = 0.533
Iteration 1199, source loss =  1.027, discriminator acc = 0.999
Iteration 1299, source loss =  1.072, discriminator acc = 0.225
Iteration 1399, source loss =  0.974, discriminator acc = 0.972
Iteration 1499, source loss =  0.965, discriminator acc = 0.636
Iteration 159

  updates = self.state_updates


0.5862498840808869
0.5862498840808869
Iteration 99, source loss =  2.423, discriminator acc = 0.073
Iteration 199, source loss =  1.438, discriminator acc = 0.151
Iteration 299, source loss =  1.460, discriminator acc = 0.939
Iteration 399, source loss =  1.515, discriminator acc = 0.510
Iteration 499, source loss =  1.371, discriminator acc = 0.287
Iteration 599, source loss =  1.519, discriminator acc = 0.461
Iteration 699, source loss =  1.265, discriminator acc = 0.947
Iteration 799, source loss =  1.090, discriminator acc = 0.813
Iteration 899, source loss =  1.070, discriminator acc = 0.871
Iteration 999, source loss =  1.272, discriminator acc = 0.351
Iteration 1099, source loss =  1.130, discriminator acc = 0.922
Iteration 1199, source loss =  1.117, discriminator acc = 0.768
Iteration 1299, source loss =  1.125, discriminator acc = 0.964
Iteration 1399, source loss =  1.041, discriminator acc = 0.267
Iteration 1499, source loss =  1.094, discriminator acc = 0.988
Iteration 159

  updates = self.state_updates


0.6047542397499085
0.6047542397499085
Iteration 99, source loss =  2.733, discriminator acc = 0.859
Iteration 199, source loss =  2.062, discriminator acc = 0.154
Iteration 299, source loss =  1.557, discriminator acc = 0.624
Iteration 399, source loss =  1.482, discriminator acc = 0.216
Iteration 499, source loss =  1.317, discriminator acc = 0.901
Iteration 599, source loss =  1.291, discriminator acc = 0.882
Iteration 699, source loss =  1.132, discriminator acc = 0.150
Iteration 799, source loss =  1.150, discriminator acc = 0.969
Iteration 899, source loss =  1.127, discriminator acc = 0.644
Iteration 999, source loss =  1.154, discriminator acc = 0.969
Iteration 1099, source loss =  1.241, discriminator acc = 0.162
Iteration 1199, source loss =  1.066, discriminator acc = 0.961
Iteration 1299, source loss =  0.996, discriminator acc = 0.436
Iteration 1399, source loss =  1.062, discriminator acc = 0.639
Iteration 1499, source loss =  0.979, discriminator acc = 0.937
Iteration 159

  updates = self.state_updates


0.6114915532588959
0.6114915532588959
Iteration 99, source loss =  4.454, discriminator acc = 0.155
Iteration 199, source loss =  2.480, discriminator acc = 0.155
Iteration 299, source loss =  1.531, discriminator acc = 0.856
Iteration 399, source loss =  1.429, discriminator acc = 0.945
Iteration 499, source loss =  1.437, discriminator acc = 0.877
Iteration 599, source loss =  1.609, discriminator acc = 0.156
Iteration 699, source loss =  1.433, discriminator acc = 0.834
Iteration 799, source loss =  1.155, discriminator acc = 0.844
Iteration 899, source loss =  1.249, discriminator acc = 0.970
Iteration 999, source loss =  1.049, discriminator acc = 0.528
Iteration 1099, source loss =  1.021, discriminator acc = 0.796
Iteration 1199, source loss =  0.993, discriminator acc = 0.371
Iteration 1299, source loss =  0.976, discriminator acc = 0.492
Iteration 1399, source loss =  0.809, discriminator acc = 0.790
Iteration 1499, source loss =  0.789, discriminator acc = 0.996
Iteration 159

  updates = self.state_updates


0.636063422203064
0.636063422203064
Iteration 99, source loss =  2.281, discriminator acc = 0.137
Iteration 199, source loss =  2.356, discriminator acc = 0.762
Iteration 299, source loss =  1.645, discriminator acc = 0.880
Iteration 399, source loss =  1.667, discriminator acc = 0.877
Iteration 499, source loss =  1.350, discriminator acc = 0.874
Iteration 599, source loss =  1.550, discriminator acc = 0.888
Iteration 699, source loss =  1.321, discriminator acc = 0.901
Iteration 799, source loss =  1.289, discriminator acc = 0.882
Iteration 899, source loss =  1.251, discriminator acc = 0.873
Iteration 999, source loss =  1.289, discriminator acc = 0.882
Iteration 1099, source loss =  1.243, discriminator acc = 0.995
Iteration 1199, source loss =  1.297, discriminator acc = 0.948
Iteration 1299, source loss =  1.465, discriminator acc = 0.483
Iteration 1399, source loss =  1.354, discriminator acc = 0.883
Iteration 1499, source loss =  1.048, discriminator acc = 0.863
Iteration 1599,

  updates = self.state_updates


0.6267835119247437
0.6267835119247437
Iteration 99, source loss =  2.441, discriminator acc = 0.847
Iteration 199, source loss =  1.537, discriminator acc = 0.865
Iteration 299, source loss =  1.552, discriminator acc = 0.855
Iteration 399, source loss =  1.503, discriminator acc = 0.871
Iteration 499, source loss =  1.373, discriminator acc = 0.883
Iteration 599, source loss =  1.590, discriminator acc = 0.873
Iteration 699, source loss =  1.298, discriminator acc = 0.885
Iteration 799, source loss =  1.325, discriminator acc = 0.875
Iteration 899, source loss =  1.265, discriminator acc = 0.878
Iteration 999, source loss =  1.179, discriminator acc = 0.887
Iteration 1099, source loss =  1.233, discriminator acc = 0.974
Iteration 1199, source loss =  1.315, discriminator acc = 0.526
Iteration 1299, source loss =  1.381, discriminator acc = 0.986
Iteration 1399, source loss =  1.064, discriminator acc = 0.959
Iteration 1499, source loss =  1.107, discriminator acc = 0.372
Iteration 159

In [None]:
# def jsd(y_true, y_pred):
#     return tf.keras.losses.kullback_leibler_divergence(y_true, y_pred)


In [None]:
# clssmodel.compile(
#         optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
#         loss={'mo': 'kld'},
#         metrics=['mae'],
#     )

# pred = clssmodel.evaluate(sc_mix_d["train"], lab_mix_d["train"])


In [None]:
# pred


In [None]:
# clssmodel_noda.compile(
#         optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
#         loss={'mo': 'kld'},
#         metrics=['mae'],
#     )

# clssmodel_noda.evaluate(sc_mix_d["train"], lab_mix_d["train"])


In [None]:
# confidence = 0.95
# if BOOTSTRAP:
#     for alpha in BOOTSTRAP_ALPHAS:
#         bootstrap_d = {}
#         for sample_id in st_sample_id_l:


#             bootstrap_d[sample_id] = {}
#             for i, num in enumerate(numlist):

#                 acc = [metrics.roc_auc_score(*gen_pred_true(num, adata_spatialLIBD_d[SAMPLE_ID_N], pred)[::-1]) for pred in pred_sp_boostrap_d[alpha][SAMPLE_ID_N]]
#                 test_mean = np.mean(acc)
#                 t_value = ss.t.ppf((1 + confidence) / 2.0, df=BOOTSTRAP_ROUNDS - 1)

#                 sd = np.std(acc, ddof=1)
#                 se = sd / np.sqrt(BOOTSTRAP_ROUNDS)

#                 ci_length = t_value * se

#                 ci_lower = test_mean - ci_length
#                 ci_upper = test_mean + ci_length

#                 bootstrap_d[sample_id][num_to_ex_d[num]] = (ci_lower, test_mean, ci_upper)


#         bootstrap_df = pd.DataFrame.from_dict(bootstrap_d)
#         display(bootstrap_df)
#         bootstrap_df.to_csv(os.path.join(results_folder, f'bootstrap_alpha{alpha}.csv'))


## 4. Predict cell fractions of spots and visualize them

In [None]:
# pred_sp_d, pred_sp_noda_d = {}, {}
# if TRAIN_USING_ALL_ST_SAMPLES:
#     for sample_id in st_sample_id_l:
#         pred_sp_d[sample_id] = clssmodel.predict(mat_sp_test_s_d[sample_id])
#         pred_sp_noda_d[sample_id] = clssmodel_noda.predict(mat_sp_test_s_d[sample_id])
# else:
#     for sample_id in st_sample_id_l:
#         pred_sp_d[sample_id] = clssmodel_d[sample_id].predict(
#             mat_sp_test_s_d[sample_id]
#         )
#         pred_sp_noda_d[sample_id] = clssmodel_noda_d[sample_id].predict(
#             mat_sp_test_s_d[sample_id]
#         )


In [None]:
# def plot_cellfraction(visnum, adata, pred_sp, ax=None):
#     """Plot predicted cell fraction for a given visnum"""
#     adata.obs["Pred_label"] = pred_sp[:, visnum]
#     # vmin = 0
#     # vmax = np.amax(pred_sp)

#     sc.pl.spatial(
#         adata,
#         img_key="hires",
#         color="Pred_label",
#         palette="Set1",
#         size=1.5,
#         legend_loc=None,
#         title=f"{sc_sub_dict[visnum]}",
#         spot_size=100,
#         show=False,
#         # vmin=vmin,
#         # vmax=vmax,
#         ax=ax,
#     )


In [None]:
# def plot_cell_layers(df):

#     layer_idx = df["spatialLIBD"].unique().astype(str)
#     samples = df["sample_id"].unique()
#     layer_idx.sort()
#     fig, ax = plt.subplots(
#         nrows=1,
#         ncols=len(samples),
#         figsize=(5 * len(samples), 5),
#         squeeze=False,
#         constrained_layout=True,
#     )

#     for idx, sample in enumerate(samples):
#         cells_of_samples = df[df["sample_id"] == sample]
#         for index in layer_idx:
#             cells_of_layer = cells_of_samples[cells_of_samples["spatialLIBD"] == index]
#             ax.flat[idx].scatter(
#                 cells_of_layer["X"], -cells_of_layer["Y"], label=index, s=17, marker="o"
#             )

#         ax.flat[idx].axis("equal")
#         ax.flat[idx].set_xticks([])
#         ax.flat[idx].set_yticks([])
#         ax.flat[idx].set_title(sample)

#     plt.legend()
#     plt.show()


In [None]:
# def plot_roc(visnum, adata, pred_sp, name, ax=None):
#     """Plot ROC for a given visnum"""

#     def layer_to_layer_number(x):
#         for char in x:
#             if char.isdigit():
#                 if int(char) in Ex_to_L_d[num_to_ex_d[visnum]]:
#                     return 1
#         return 0

#     y_pred = pred_sp[:, visnum]
#     y_true = adata.obs["spatialLIBD"].map(layer_to_layer_number).fillna(0)
#     # print(y_true)
#     # print(y_true.isna().sum())
#     RocCurveDisplay.from_predictions(y_true=y_true, y_pred=y_pred, name=name, ax=ax)


In [None]:
# # plot_cell_layers(adata_spatialLIBD_151673.obs)

# fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(5, 5), constrained_layout=True)

# sc.pl.spatial(
#     adata_spatialLIBD_d[SAMPLE_ID_N],
#     img_key=None,
#     color="spatialLIBD",
#     palette="Accent_r",
#     size=1.5,
#     title=SAMPLE_ID_N,
#     # legend_loc = 4,
#     spot_size=100,
#     show=False,
#     ax=ax,
# )

# ax.axis("equal")
# ax.set_xlabel("")
# ax.set_ylabel("")

# fig.show()


In [None]:
# fig, ax = plt.subplots(2, 5, figsize=(20, 8), constrained_layout=True)

# for i, num in enumerate(numlist):
#     plot_cellfraction(
#         num, adata_spatialLIBD_d[SAMPLE_ID_N], pred_sp_d[SAMPLE_ID_N], ax.flat[i]
#     )
#     ax.flat[i].axis("equal")
#     ax.flat[i].set_xlabel("")
#     ax.flat[i].set_ylabel("")

# fig.show()
# # plt.close()

# fig, ax = plt.subplots(
#     2, 5, figsize=(20, 8), constrained_layout=True, sharex=True, sharey=True
# )

# for i, num in enumerate(numlist):
#     plot_roc(
#         num,
#         adata_spatialLIBD_d[SAMPLE_ID_N],
#         pred_sp_d[SAMPLE_ID_N],
#         "CellDART",
#         ax.flat[i],
#     )
#     plot_roc(
#         num,
#         adata_spatialLIBD_d[SAMPLE_ID_N],
#         pred_sp_noda_d[SAMPLE_ID_N],
#         "NN_wo_da",
#         ax.flat[i],
#     )
#     ax.flat[i].plot([0, 1], [0, 1], transform=ax.flat[i].transAxes, ls="--", color="k")
#     ax.flat[i].set_aspect("equal")
#     ax.flat[i].set_xlim([0, 1])
#     ax.flat[i].set_ylim([0, 1])

#     ax.flat[i].set_title(f"{sc_sub_dict[num]}")

#     if i >= len(numlist) - 5:
#         ax.flat[i].set_xlabel("FPR")
#     else:
#         ax.flat[i].set_xlabel("")
#     if i % 5 == 0:
#         ax.flat[i].set_ylabel("TPR")
#     else:
#         ax.flat[i].set_ylabel("")

# fig.show()
# # plt.close()


- For comparison: predictions on the single-cell mixture data (pseudospots)


In [None]:
# if TRAIN_USING_ALL_ST_SAMPLES:
#     pred_mix = clssmodel.predict(sc_mix_test_s)
# else:
#     pred_mix = clssmodel_d[SAMPLE_ID_N].predict(sc_mix_test_s)


# cell_type_nums = sc_sub_dict.keys()
# nrows = ceil(len(cell_type_nums) / 5)

# line_kws = {"color": "tab:orange"}
# scatter_kws = {"s": 5}

# props = dict(facecolor="w", alpha=0.5)

# fig, ax = plt.subplots(
#     nrows,
#     5,
#     figsize=(20, 4 * nrows),
#     constrained_layout=True,
#     sharex=False,
#     sharey=True,
# )
# for i, num in enumerate(cell_type_nums):
#     sns.regplot(
#         x=pred_mix[:, num],
#         y=lab_mix_test[:, num],
#         line_kws=line_kws,
#         scatter_kws=scatter_kws,
#         ax=ax.flat[i],
#     ).set_title(sc_sub_dict[num])
#     ax.flat[i].set_aspect("equal")

#     ax.flat[i].set_xlabel("Predicted Proportion")
#     if i % 5 == 0:
#         ax.flat[i].set_ylabel("True Proportion")
#     else:
#         ax.flat[i].set_ylabel("")
#     ax.flat[i].set_xlim([0, 1])
#     ax.flat[i].set_ylim([0, 1])

#     textstr = f"MSE: {mean_squared_error(pred_mix[:,num], lab_mix_test[:,num]):.5f}"

#     # place a text box in upper left in axes coords
#     ax.flat[i].text(
#         0.95,
#         0.05,
#         textstr,
#         transform=ax.flat[i].transAxes,
#         verticalalignment="bottom",
#         horizontalalignment="right",
#         bbox=props,
#     )

# for i in range(len(cell_type_nums), nrows * 5):
#     ax.flat[i].axis("off")

# plt.show()


In [None]:
# print(
#     "\n".join(
#         f"{m.__name__} {m.__version__}"
#         for m in globals().values()
#         if getattr(m, "__version__", None)
#     )
# )
