# CellDART Example Code: mouse brain 
## (10x Visium of anterior mouse brain + scRNA-seq data of mouse brain)

In [1]:
import glob
import os
import pickle
from math import ceil

import anndata as ad
import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import scanpy as sc
import seaborn as sns
import scipy.stats as ss

from sklearn.model_selection import train_test_split
from sklearn.metrics import RocCurveDisplay
from sklearn.metrics import mean_squared_error
from sklearn import metrics

import tensorflow as tf  # TensorFlow registers PluggableDevices here.

from CellDART import da_cellfraction
from CellDART.utils import random_mix

from src.utils.data_loading import load_spatial, load_sc, get_selected_dir

from tqdm.autonotebook import tqdm


  from tqdm.autonotebook import tqdm


In [2]:
# Ask TensorFlow which physical devices it can see. As the last expression of
# the cell, the returned list is echoed as the cell output.
tf.config.list_physical_devices()


[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
 PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [3]:
# Run configuration for this notebook. Every value below is a plain module-level
# constant read by the later cells.

# Selects the "all ST slides" branch of the training cell and is forwarded to
# load_spatial(train_using_all_st_samples=...).
TRAIN_USING_ALL_ST_SAMPLES = False
# N_MARKERS and ALL_GENES are forwarded to get_selected_dir() together with DATA_DIR.
N_MARKERS = 20
ALL_GENES = False

# Forwarded to load_spatial(st_split=...).
ST_SPLIT = False
# Forwarded to load_sc(n_spots=..., n_mix=...).
# Presumably the number of pseudo-spots and cells mixed per pseudo-spot — TODO confirm.
N_SPOTS = 20000
N_MIX = 8
# Forwarded to both load_spatial() and load_sc().
SCALER_NAME = "celldart"


# Only referenced by the commented-out evaluation/plotting cells of this notebook.
SAMPLE_ID_N = "151673"
# Value for the `initial_train_epochs` argument of da_cellfraction.train.
INITIAL_TRAIN_EPOCHS = 10


# BATCH_SIZE, ALPHA and N_ITER are passed to da_cellfraction.train as
# batch_size, alpha and n_iterations.
BATCH_SIZE = 512
ALPHA = 0.6
# Value for the `alpha_lr` argument of da_cellfraction.train.
ALPHA_LR = 5
N_ITER = 3000

# Root data directory handed to get_selected_dir().
DATA_DIR = "../AGrEDA/data"
# DATA_DIR = "./data_combine/"

# BOOTSTRAP selects the bootstrap branch of the training cell (only reached when
# TRAIN_USING_ALL_ST_SAMPLES is False). That branch runs BOOTSTRAP_ROUNDS
# trainings per slide for each alpha in BOOTSTRAP_ALPHAS.
BOOTSTRAP = False
BOOTSTRAP_ROUNDS = 10
BOOTSTRAP_ALPHAS = [0.6, 1 / 0.6]

# Used to build the output directory model/<MODEL_NAME>/<MODEL_VERSION>.
MODEL_NAME = "CellDART"
MODEL_VERSION = "V1"


In [4]:
# All artifacts of this run live under model/<MODEL_NAME>/<MODEL_VERSION>.
model_folder = os.path.join(*("model", MODEL_NAME, MODEL_VERSION))

model_folder_missing = not os.path.isdir(model_folder)
if model_folder_missing:
    os.makedirs(model_folder)
    # The path is echoed only when the directory was just created.
    print(model_folder)


model/CellDART/V1


## 1. Data load
### Load spatial (ST) and scRNA-seq data via `load_spatial` / `load_sc`

In [5]:
# Load spatial data
mat_sp_d, mat_sp_train, st_sample_id_l = load_spatial(
    get_selected_dir(DATA_DIR, N_MARKERS, ALL_GENES),
    SCALER_NAME,
    train_using_all_st_samples=TRAIN_USING_ALL_ST_SAMPLES,
    st_split=ST_SPLIT,
)

# Load sc data
sc_mix_d, lab_mix_d, sc_sub_dict, sc_sub_dict2 = load_sc(
    get_selected_dir(DATA_DIR, N_MARKERS, ALL_GENES),
    SCALER_NAME,
    n_mix=N_MIX,
    n_spots=N_SPOTS,
)


In [6]:
# One sub-folder per training stage below the model folder: "advtrain" for the
# adversarially trained model, "pretrain" for the model without domain adaptation.
advtrain_folder = os.path.join(model_folder, "advtrain")
pretrain_folder = os.path.join(model_folder, "pretrain")

# exist_ok=True makes the call a no-op when the directory is already there.
os.makedirs(advtrain_folder, exist_ok=True)
os.makedirs(pretrain_folder, exist_ok=True)


In [7]:
# st_sample_id_l = [SAMPLE_ID_N]


In [8]:
# Train CellDART in one of three modes, selected by the configuration flags:
#   * TRAIN_USING_ALL_ST_SAMPLES: one model trained against `mat_sp_train`.
#   * BOOTSTRAP: for every alpha in BOOTSTRAP_ALPHAS and every ST slide, run
#     BOOTSTRAP_ROUNDS trainings (seeded with the round index) and collect the
#     predictions on that slide in `pred_sp_boostrap_d[alpha][sample_id]`.
#   * otherwise: one model per ST slide; models and embedders are saved under
#     `pretrain_folder/<sample_id>` and `advtrain_folder/<sample_id>`.
#
# NOTE(review): da_cellfraction.train is unpacked as four values
# (embs, embs_noda, clssmodel, clssmodel_noda) in every branch, matching the
# per-slide branch, which is the one the recorded output shows running.
# Confirm against the installed CellDART version.
if TRAIN_USING_ALL_ST_SAMPLES:
    print("Adversarial training for all ST slides")
    # `mat_sp_train` is the training matrix returned by load_spatial();
    # presumably the combined slides when train_using_all_st_samples is set — TODO confirm.
    embs, embs_noda, clssmodel, clssmodel_noda = da_cellfraction.train(
        sc_mix_d["train"],
        lab_mix_d["train"],
        mat_sp_train,
        alpha=ALPHA,
        alpha_lr=ALPHA_LR,
        emb_dim=64,
        batch_size=BATCH_SIZE,
        n_iterations=N_ITER,
        initial_train=True,
        initial_train_epochs=INITIAL_TRAIN_EPOCHS,
    )
elif BOOTSTRAP:
    pred_sp_boostrap_d = {}

    # Three stacked progress bars: alphas > samples > bootstrap rounds.
    outer = tqdm(total=len(BOOTSTRAP_ALPHAS), desc="Alphas", position=0)
    inner1 = tqdm(total=len(st_sample_id_l), desc="Sample", position=1)
    inner2 = tqdm(total=BOOTSTRAP_ROUNDS, desc="Bootstrap #", position=2)
    for alpha in BOOTSTRAP_ALPHAS:

        inner1.refresh()  # force print final state
        inner1.reset()  # reuse bar

        pred_sp_boostrap_d[alpha] = {}

        for sample_id in st_sample_id_l:
            inner2.refresh()  # force print final state
            inner2.reset()  # reuse bar

            pred_sp_boostrap_d[alpha][sample_id] = []
            for i in range(BOOTSTRAP_ROUNDS):
                print(f"Adversarial training for ST slide {sample_id}: ")
                # Same indexing as the per-slide branch: sample id first, then split.
                embs, _, clssmodel, _ = da_cellfraction.train(
                    sc_mix_d["train"],
                    lab_mix_d["train"],
                    mat_sp_d[sample_id]["train"],
                    alpha=alpha,
                    alpha_lr=ALPHA_LR,
                    emb_dim=64,
                    batch_size=BATCH_SIZE,
                    n_iterations=N_ITER,
                    initial_train=True,
                    initial_train_epochs=INITIAL_TRAIN_EPOCHS,
                    seed=i,
                )

                pred_sp_boostrap_d[alpha][sample_id].append(
                    clssmodel.predict(mat_sp_d[sample_id]["train"])
                )
                inner2.update(1)
            inner1.update(1)
        outer.update(1)

    # Release the progress bars once all runs are done.
    inner2.close()
    inner1.close()
    outer.close()

else:
    # embs_d, clssmodel_d, clssmodel_noda_d = {}, {}, {}
    for sample_id in st_sample_id_l:
        print(f"Adversarial training for ST slide {sample_id}: ")

        # Per-slide output folders for both training stages.
        advtrain_sample_folder = os.path.join(advtrain_folder, sample_id)
        pretrain_sample_folder = os.path.join(pretrain_folder, sample_id)
        os.makedirs(advtrain_sample_folder, exist_ok=True)
        os.makedirs(pretrain_sample_folder, exist_ok=True)

        embs, embs_noda, clssmodel, clssmodel_noda = da_cellfraction.train(
            sc_mix_d["train"],
            lab_mix_d["train"],
            mat_sp_d[sample_id]["train"],
            alpha=ALPHA,
            alpha_lr=ALPHA_LR,
            emb_dim=64,
            batch_size=BATCH_SIZE,
            n_iterations=N_ITER,
            initial_train=True,
            initial_train_epochs=INITIAL_TRAIN_EPOCHS,
            # The numeric sample id doubles as the seed, so each slide gets a
            # distinct seed that stays the same across reruns.
            seed=int(sample_id),
        )
        # embs_d[sample_id] = embs
        # clssmodel_d[sample_id] = clssmodel
        # clssmodel_noda_d[sample_id] = clssmodel_noda

        # print(keras.losses.KLDivergence()(lab_mix_d["train"], clssmodel_noda.predict(sc_mix_d["train"])).numpy())
        # print(keras.losses.KLDivergence()(lab_mix_d["train"], clssmodel.predict(sc_mix_d["train"])).numpy())
        # Save the models: "*_noda" objects go to the pretrain folder, the
        # adversarially trained ones to the advtrain folder.
        clssmodel_noda.save(os.path.join(pretrain_sample_folder, "final_model"))
        clssmodel.save(os.path.join(advtrain_sample_folder, "final_model"))

        embs_noda.save(os.path.join(pretrain_sample_folder, "embs"))
        embs.save(os.path.join(advtrain_sample_folder, "embs"))


Adversarial training for ST slide 151507: 
Train on 20000 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
initial_train_done


  updates = self.state_updates


0.6097667330265045
0.6097667330265045
Iteration 99, source loss =  3.951, discriminator acc = 0.180
Iteration 199, source loss =  1.888, discriminator acc = 0.455
Iteration 299, source loss =  1.426, discriminator acc = 0.928
Iteration 399, source loss =  1.545, discriminator acc = 0.866
Iteration 499, source loss =  1.364, discriminator acc = 0.988
Iteration 599, source loss =  1.210, discriminator acc = 0.154
Iteration 699, source loss =  1.252, discriminator acc = 0.996
Iteration 799, source loss =  1.068, discriminator acc = 0.740
Iteration 899, source loss =  1.313, discriminator acc = 0.884
Iteration 999, source loss =  1.078, discriminator acc = 0.961
Iteration 1099, source loss =  0.912, discriminator acc = 0.992
Iteration 1199, source loss =  0.901, discriminator acc = 0.128
Iteration 1299, source loss =  0.814, discriminator acc = 0.246
Iteration 1399, source loss =  0.813, discriminator acc = 0.349
Iteration 1499, source loss =  0.787, discriminator acc = 0.143
Iteration 159

  updates = self.state_updates


0.6339968461990356
0.6339968461990356
Iteration 99, source loss =  4.053, discriminator acc = 0.066
Iteration 199, source loss =  2.089, discriminator acc = 0.855
Iteration 299, source loss =  2.219, discriminator acc = 0.821
Iteration 399, source loss =  1.457, discriminator acc = 0.856
Iteration 499, source loss =  1.600, discriminator acc = 0.843
Iteration 599, source loss =  1.428, discriminator acc = 0.835
Iteration 699, source loss =  1.335, discriminator acc = 0.836
Iteration 799, source loss =  1.223, discriminator acc = 0.879
Iteration 899, source loss =  1.190, discriminator acc = 0.953
Iteration 999, source loss =  1.421, discriminator acc = 0.976
Iteration 1099, source loss =  1.285, discriminator acc = 0.951
Iteration 1199, source loss =  1.171, discriminator acc = 0.897
Iteration 1299, source loss =  1.066, discriminator acc = 0.455
Iteration 1399, source loss =  1.040, discriminator acc = 0.789
Iteration 1499, source loss =  0.915, discriminator acc = 0.872
Iteration 159

  updates = self.state_updates


0.6196031370162964
0.6196031370162964
Iteration 99, source loss =  3.593, discriminator acc = 0.056
Iteration 199, source loss =  1.802, discriminator acc = 0.818
Iteration 299, source loss =  1.440, discriminator acc = 0.812
Iteration 399, source loss =  1.290, discriminator acc = 0.473
Iteration 499, source loss =  1.647, discriminator acc = 0.187
Iteration 599, source loss =  1.279, discriminator acc = 0.194
Iteration 699, source loss =  1.316, discriminator acc = 0.788
Iteration 799, source loss =  1.314, discriminator acc = 0.198
Iteration 899, source loss =  1.266, discriminator acc = 0.951
Iteration 999, source loss =  1.178, discriminator acc = 0.508
Iteration 1099, source loss =  1.168, discriminator acc = 0.881
Iteration 1199, source loss =  1.126, discriminator acc = 0.699
Iteration 1299, source loss =  1.141, discriminator acc = 0.998
Iteration 1399, source loss =  1.052, discriminator acc = 0.897
Iteration 1499, source loss =  0.993, discriminator acc = 0.997
Iteration 159

  updates = self.state_updates


0.6635860331535339
0.6635860331535339
Iteration 99, source loss =  3.144, discriminator acc = 0.698
Iteration 199, source loss =  1.824, discriminator acc = 0.876
Iteration 299, source loss =  1.579, discriminator acc = 0.824
Iteration 399, source loss =  1.782, discriminator acc = 0.281
Iteration 499, source loss =  1.430, discriminator acc = 0.993
Iteration 599, source loss =  1.298, discriminator acc = 0.871
Iteration 699, source loss =  1.148, discriminator acc = 0.792
Iteration 799, source loss =  1.150, discriminator acc = 0.959
Iteration 899, source loss =  0.955, discriminator acc = 0.896
Iteration 999, source loss =  0.886, discriminator acc = 0.279
Iteration 1099, source loss =  0.805, discriminator acc = 0.348
Iteration 1199, source loss =  0.777, discriminator acc = 0.737
Iteration 1299, source loss =  0.720, discriminator acc = 0.853
Iteration 1399, source loss =  0.790, discriminator acc = 0.908
Iteration 1499, source loss =  0.761, discriminator acc = 0.463
Iteration 159

  updates = self.state_updates


0.6344990928649903
0.6344990928649903
Iteration 99, source loss =  3.426, discriminator acc = 0.155
Iteration 199, source loss =  2.906, discriminator acc = 0.714
Iteration 299, source loss =  2.118, discriminator acc = 0.993
Iteration 399, source loss =  1.659, discriminator acc = 0.209
Iteration 499, source loss =  1.161, discriminator acc = 0.878
Iteration 599, source loss =  1.207, discriminator acc = 0.979
Iteration 699, source loss =  1.436, discriminator acc = 0.179
Iteration 799, source loss =  1.226, discriminator acc = 0.938
Iteration 899, source loss =  1.208, discriminator acc = 0.931
Iteration 999, source loss =  1.054, discriminator acc = 0.626
Iteration 1099, source loss =  0.918, discriminator acc = 0.977
Iteration 1199, source loss =  0.988, discriminator acc = 0.187
Iteration 1299, source loss =  0.844, discriminator acc = 0.889
Iteration 1399, source loss =  0.768, discriminator acc = 0.946
Iteration 1499, source loss =  0.783, discriminator acc = 0.218
Iteration 159

  updates = self.state_updates


0.60918740940094
0.60918740940094
Iteration 99, source loss =  3.523, discriminator acc = 0.226
Iteration 199, source loss =  2.602, discriminator acc = 0.594
Iteration 299, source loss =  1.483, discriminator acc = 0.145
Iteration 399, source loss =  1.395, discriminator acc = 0.148
Iteration 499, source loss =  1.473, discriminator acc = 0.149
Iteration 599, source loss =  1.129, discriminator acc = 0.495
Iteration 699, source loss =  1.318, discriminator acc = 0.150
Iteration 799, source loss =  1.202, discriminator acc = 0.149
Iteration 899, source loss =  1.108, discriminator acc = 0.903
Iteration 999, source loss =  1.146, discriminator acc = 0.805
Iteration 1099, source loss =  1.218, discriminator acc = 0.144
Iteration 1199, source loss =  1.199, discriminator acc = 0.034
Iteration 1299, source loss =  1.059, discriminator acc = 0.998
Iteration 1399, source loss =  1.183, discriminator acc = 0.165
Iteration 1499, source loss =  0.968, discriminator acc = 0.998
Iteration 1599, s

  updates = self.state_updates


0.5910779253959656
0.5910779253959656
Iteration 99, source loss =  2.809, discriminator acc = 0.170
Iteration 199, source loss =  2.593, discriminator acc = 0.170
Iteration 299, source loss =  1.450, discriminator acc = 0.868
Iteration 399, source loss =  1.481, discriminator acc = 1.000
Iteration 499, source loss =  1.333, discriminator acc = 0.170
Iteration 599, source loss =  1.454, discriminator acc = 0.706
Iteration 699, source loss =  1.343, discriminator acc = 0.158
Iteration 799, source loss =  1.082, discriminator acc = 0.919
Iteration 899, source loss =  1.159, discriminator acc = 0.164
Iteration 999, source loss =  1.114, discriminator acc = 0.999
Iteration 1099, source loss =  1.123, discriminator acc = 0.272
Iteration 1199, source loss =  1.019, discriminator acc = 1.000
Iteration 1299, source loss =  1.019, discriminator acc = 0.202
Iteration 1399, source loss =  0.998, discriminator acc = 0.976
Iteration 1499, source loss =  1.097, discriminator acc = 0.148
Iteration 159

  updates = self.state_updates


0.5862530634880065
0.5862530634880065
Iteration 99, source loss =  2.424, discriminator acc = 0.072
Iteration 199, source loss =  1.450, discriminator acc = 0.151
Iteration 299, source loss =  1.466, discriminator acc = 0.939
Iteration 399, source loss =  1.517, discriminator acc = 0.520
Iteration 499, source loss =  1.378, discriminator acc = 0.271
Iteration 599, source loss =  1.388, discriminator acc = 0.425
Iteration 699, source loss =  1.305, discriminator acc = 0.995
Iteration 799, source loss =  1.098, discriminator acc = 0.902
Iteration 899, source loss =  1.123, discriminator acc = 0.849
Iteration 999, source loss =  1.209, discriminator acc = 0.685
Iteration 1099, source loss =  1.093, discriminator acc = 0.714
Iteration 1199, source loss =  1.188, discriminator acc = 0.659
Iteration 1299, source loss =  1.255, discriminator acc = 0.717
Iteration 1399, source loss =  1.043, discriminator acc = 0.467
Iteration 1499, source loss =  1.022, discriminator acc = 1.000
Iteration 159

  updates = self.state_updates


0.6047545536518097
0.6047545536518097
Iteration 99, source loss =  2.733, discriminator acc = 0.859
Iteration 199, source loss =  2.061, discriminator acc = 0.154
Iteration 299, source loss =  1.557, discriminator acc = 0.622
Iteration 399, source loss =  1.482, discriminator acc = 0.220
Iteration 499, source loss =  1.317, discriminator acc = 0.900
Iteration 599, source loss =  1.289, discriminator acc = 0.906
Iteration 699, source loss =  1.125, discriminator acc = 0.151
Iteration 799, source loss =  1.163, discriminator acc = 0.974
Iteration 899, source loss =  1.112, discriminator acc = 0.719
Iteration 999, source loss =  1.195, discriminator acc = 0.953
Iteration 1099, source loss =  1.334, discriminator acc = 0.158
Iteration 1199, source loss =  1.056, discriminator acc = 0.918
Iteration 1299, source loss =  1.013, discriminator acc = 0.408
Iteration 1399, source loss =  1.055, discriminator acc = 0.681
Iteration 1499, source loss =  0.963, discriminator acc = 0.849
Iteration 159

  updates = self.state_updates


0.6115017732143402
0.6115017732143402
Iteration 99, source loss =  4.454, discriminator acc = 0.155
Iteration 199, source loss =  2.479, discriminator acc = 0.155
Iteration 299, source loss =  1.531, discriminator acc = 0.856
Iteration 399, source loss =  1.428, discriminator acc = 0.943
Iteration 499, source loss =  1.438, discriminator acc = 0.877
Iteration 599, source loss =  1.599, discriminator acc = 0.157
Iteration 699, source loss =  1.439, discriminator acc = 0.807
Iteration 799, source loss =  1.169, discriminator acc = 0.835
Iteration 899, source loss =  1.242, discriminator acc = 0.968
Iteration 999, source loss =  1.050, discriminator acc = 0.591
Iteration 1099, source loss =  1.024, discriminator acc = 0.819
Iteration 1199, source loss =  0.997, discriminator acc = 0.331
Iteration 1299, source loss =  0.998, discriminator acc = 0.493
Iteration 1399, source loss =  0.814, discriminator acc = 0.932
Iteration 1499, source loss =  0.788, discriminator acc = 0.989
Iteration 159

  updates = self.state_updates


0.6360812035560608
0.6360812035560608
Iteration 99, source loss =  2.281, discriminator acc = 0.137
Iteration 199, source loss =  2.357, discriminator acc = 0.762
Iteration 299, source loss =  1.646, discriminator acc = 0.880
Iteration 399, source loss =  1.667, discriminator acc = 0.877
Iteration 499, source loss =  1.350, discriminator acc = 0.874
Iteration 599, source loss =  1.551, discriminator acc = 0.888
Iteration 699, source loss =  1.321, discriminator acc = 0.901
Iteration 799, source loss =  1.287, discriminator acc = 0.882
Iteration 899, source loss =  1.251, discriminator acc = 0.873
Iteration 999, source loss =  1.286, discriminator acc = 0.882
Iteration 1099, source loss =  1.243, discriminator acc = 0.996
Iteration 1199, source loss =  1.296, discriminator acc = 0.949
Iteration 1299, source loss =  1.459, discriminator acc = 0.491
Iteration 1399, source loss =  1.342, discriminator acc = 0.883
Iteration 1499, source loss =  1.044, discriminator acc = 0.855
Iteration 159

  updates = self.state_updates


0.6268076364517212
0.6268076364517212
Iteration 99, source loss =  2.441, discriminator acc = 0.847
Iteration 199, source loss =  1.537, discriminator acc = 0.865
Iteration 299, source loss =  1.552, discriminator acc = 0.855
Iteration 399, source loss =  1.503, discriminator acc = 0.871
Iteration 499, source loss =  1.373, discriminator acc = 0.883
Iteration 599, source loss =  1.589, discriminator acc = 0.873
Iteration 699, source loss =  1.298, discriminator acc = 0.885
Iteration 799, source loss =  1.324, discriminator acc = 0.875
Iteration 899, source loss =  1.265, discriminator acc = 0.878
Iteration 999, source loss =  1.180, discriminator acc = 0.887
Iteration 1099, source loss =  1.234, discriminator acc = 0.974
Iteration 1199, source loss =  1.316, discriminator acc = 0.524
Iteration 1299, source loss =  1.381, discriminator acc = 0.987
Iteration 1399, source loss =  1.063, discriminator acc = 0.959
Iteration 1499, source loss =  1.106, discriminator acc = 0.371
Iteration 159

In [9]:
# def jsd(y_true, y_pred):
#     return tf.keras.losses.kullback_leibler_divergence(y_true, y_pred)


In [10]:
# clssmodel.compile(
#         optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
#         loss={'mo': 'kld'},
#         metrics=['mae'],
#     )

# pred = clssmodel.evaluate(sc_mix_d["train"], lab_mix_d["train"])


In [11]:
# pred


In [12]:
# clssmodel_noda.compile(
#         optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
#         loss={'mo': 'kld'},
#         metrics=['mae'],
#     )

# clssmodel_noda.evaluate(sc_mix_d["train"], lab_mix_d["train"])


In [13]:
# confidence = 0.95
# if BOOTSTRAP:
#     for alpha in BOOTSTRAP_ALPHAS:
#         bootstrap_d = {}
#         for sample_id in st_sample_id_l:


#             bootstrap_d[sample_id] = {}
#             for i, num in enumerate(numlist):

#                 acc = [metrics.roc_auc_score(*gen_pred_true(num, adata_spatialLIBD_d[SAMPLE_ID_N], pred)[::-1]) for pred in pred_sp_boostrap_d[alpha][SAMPLE_ID_N]]
#                 test_mean = np.mean(acc)
#                 t_value = ss.t.ppf((1 + confidence) / 2.0, df=BOOTSTRAP_ROUNDS - 1)

#                 sd = np.std(acc, ddof=1)
#                 se = sd / np.sqrt(BOOTSTRAP_ROUNDS)

#                 ci_length = t_value * se

#                 ci_lower = test_mean - ci_length
#                 ci_upper = test_mean + ci_length

#                 bootstrap_d[sample_id][num_to_ex_d[num]] = (ci_lower, test_mean, ci_upper)


#         bootstrap_df = pd.DataFrame.from_dict(bootstrap_d)
#         display(bootstrap_df)
#         bootstrap_df.to_csv(os.path.join(results_folder, f'bootstrap_alpha{alpha}.csv'))


## 4. Predict cell fraction of spots and visualization

In [14]:
# pred_sp_d, pred_sp_noda_d = {}, {}
# if TRAIN_USING_ALL_ST_SAMPLES:
#     for sample_id in st_sample_id_l:
#         pred_sp_d[sample_id] = clssmodel.predict(mat_sp_test_s_d[sample_id])
#         pred_sp_noda_d[sample_id] = clssmodel_noda.predict(mat_sp_test_s_d[sample_id])
# else:
#     for sample_id in st_sample_id_l:
#         pred_sp_d[sample_id] = clssmodel_d[sample_id].predict(
#             mat_sp_test_s_d[sample_id]
#         )
#         pred_sp_noda_d[sample_id] = clssmodel_noda_d[sample_id].predict(
#             mat_sp_test_s_d[sample_id]
#         )


In [15]:
# def plot_cellfraction(visnum, adata, pred_sp, ax=None):
#     """Plot predicted cell fraction for a given visnum"""
#     adata.obs["Pred_label"] = pred_sp[:, visnum]
#     # vmin = 0
#     # vmax = np.amax(pred_sp)

#     sc.pl.spatial(
#         adata,
#         img_key="hires",
#         color="Pred_label",
#         palette="Set1",
#         size=1.5,
#         legend_loc=None,
#         title=f"{sc_sub_dict[visnum]}",
#         spot_size=100,
#         show=False,
#         # vmin=vmin,
#         # vmax=vmax,
#         ax=ax,
#     )


In [16]:
# def plot_cell_layers(df):

#     layer_idx = df["spatialLIBD"].unique().astype(str)
#     samples = df["sample_id"].unique()
#     layer_idx.sort()
#     fig, ax = plt.subplots(
#         nrows=1,
#         ncols=len(samples),
#         figsize=(5 * len(samples), 5),
#         squeeze=False,
#         constrained_layout=True,
#     )

#     for idx, sample in enumerate(samples):
#         cells_of_samples = df[df["sample_id"] == sample]
#         for index in layer_idx:
#             cells_of_layer = cells_of_samples[cells_of_samples["spatialLIBD"] == index]
#             ax.flat[idx].scatter(
#                 cells_of_layer["X"], -cells_of_layer["Y"], label=index, s=17, marker="o"
#             )

#         ax.flat[idx].axis("equal")
#         ax.flat[idx].set_xticks([])
#         ax.flat[idx].set_yticks([])
#         ax.flat[idx].set_title(sample)

#     plt.legend()
#     plt.show()


In [17]:
# def plot_roc(visnum, adata, pred_sp, name, ax=None):
#     """Plot ROC for a given visnum"""

#     def layer_to_layer_number(x):
#         for char in x:
#             if char.isdigit():
#                 if int(char) in Ex_to_L_d[num_to_ex_d[visnum]]:
#                     return 1
#         return 0

#     y_pred = pred_sp[:, visnum]
#     y_true = adata.obs["spatialLIBD"].map(layer_to_layer_number).fillna(0)
#     # print(y_true)
#     # print(y_true.isna().sum())
#     RocCurveDisplay.from_predictions(y_true=y_true, y_pred=y_pred, name=name, ax=ax)


In [18]:
# # plot_cell_layers(adata_spatialLIBD_151673.obs)

# fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(5, 5), constrained_layout=True)

# sc.pl.spatial(
#     adata_spatialLIBD_d[SAMPLE_ID_N],
#     img_key=None,
#     color="spatialLIBD",
#     palette="Accent_r",
#     size=1.5,
#     title=SAMPLE_ID_N,
#     # legend_loc = 4,
#     spot_size=100,
#     show=False,
#     ax=ax,
# )

# ax.axis("equal")
# ax.set_xlabel("")
# ax.set_ylabel("")

# fig.show()


In [19]:
# fig, ax = plt.subplots(2, 5, figsize=(20, 8), constrained_layout=True)

# for i, num in enumerate(numlist):
#     plot_cellfraction(
#         num, adata_spatialLIBD_d[SAMPLE_ID_N], pred_sp_d[SAMPLE_ID_N], ax.flat[i]
#     )
#     ax.flat[i].axis("equal")
#     ax.flat[i].set_xlabel("")
#     ax.flat[i].set_ylabel("")

# fig.show()
# # plt.close()

# fig, ax = plt.subplots(
#     2, 5, figsize=(20, 8), constrained_layout=True, sharex=True, sharey=True
# )

# for i, num in enumerate(numlist):
#     plot_roc(
#         num,
#         adata_spatialLIBD_d[SAMPLE_ID_N],
#         pred_sp_d[SAMPLE_ID_N],
#         "CellDART",
#         ax.flat[i],
#     )
#     plot_roc(
#         num,
#         adata_spatialLIBD_d[SAMPLE_ID_N],
#         pred_sp_noda_d[SAMPLE_ID_N],
#         "NN_wo_da",
#         ax.flat[i],
#     )
#     ax.flat[i].plot([0, 1], [0, 1], transform=ax.flat[i].transAxes, ls="--", color="k")
#     ax.flat[i].set_aspect("equal")
#     ax.flat[i].set_xlim([0, 1])
#     ax.flat[i].set_ylim([0, 1])

#     ax.flat[i].set_title(f"{sc_sub_dict[num]}")

#     if i >= len(numlist) - 5:
#         ax.flat[i].set_xlabel("FPR")
#     else:
#         ax.flat[i].set_xlabel("")
#     if i % 5 == 0:
#         ax.flat[i].set_ylabel("TPR")
#     else:
#         ax.flat[i].set_ylabel("")

# fig.show()
# # plt.close()


- cf. Prediction of Mixture (pseudospots)


In [20]:
# if TRAIN_USING_ALL_ST_SAMPLES:
#     pred_mix = clssmodel.predict(sc_mix_test_s)
# else:
#     pred_mix = clssmodel_d[SAMPLE_ID_N].predict(sc_mix_test_s)


# cell_type_nums = sc_sub_dict.keys()
# nrows = ceil(len(cell_type_nums) / 5)

# line_kws = {"color": "tab:orange"}
# scatter_kws = {"s": 5}

# props = dict(facecolor="w", alpha=0.5)

# fig, ax = plt.subplots(
#     nrows,
#     5,
#     figsize=(20, 4 * nrows),
#     constrained_layout=True,
#     sharex=False,
#     sharey=True,
# )
# for i, num in enumerate(cell_type_nums):
#     sns.regplot(
#         x=pred_mix[:, num],
#         y=lab_mix_test[:, num],
#         line_kws=line_kws,
#         scatter_kws=scatter_kws,
#         ax=ax.flat[i],
#     ).set_title(sc_sub_dict[num])
#     ax.flat[i].set_aspect("equal")

#     ax.flat[i].set_xlabel("Predicted Proportion")
#     if i % 5 == 0:
#         ax.flat[i].set_ylabel("True Proportion")
#     else:
#         ax.flat[i].set_ylabel("")
#     ax.flat[i].set_xlim([0, 1])
#     ax.flat[i].set_ylim([0, 1])

#     textstr = f"MSE: {mean_squared_error(pred_mix[:,num], lab_mix_test[:,num]):.5f}"

#     # place a text box in the lower right, in axes coords
#     ax.flat[i].text(
#         0.95,
#         0.05,
#         textstr,
#         transform=ax.flat[i].transAxes,
#         verticalalignment="bottom",
#         horizontalalignment="right",
#         bbox=props,
#     )

# for i in range(len(cell_type_nums), nrows * 5):
#     ax.flat[i].axis("off")

# plt.show()


In [21]:
# print(
#     "\n".join(
#         f"{m.__name__} {m.__version__}"
#         for m in globals().values()
#         if getattr(m, "__version__", None)
#     )
# )
