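# CellDART Example Code (adapted from the mouse brain example)
## (The original example used 10x Visium of the anterior mouse brain + mouse brain scRNA-seq data; this notebook is configured for the spatialLIBD Visium slides, e.g. sample 151673.)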

In [1]:
import glob
import os
import pickle
from math import ceil

import anndata as ad
import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import scanpy as sc
import seaborn as sns
import scipy.stats as ss

from sklearn.model_selection import train_test_split
from sklearn.metrics import RocCurveDisplay
from sklearn.metrics import mean_squared_error
from sklearn import metrics

import tensorflow as tf  # TensorFlow registers PluggableDevices here.

from CellDART import da_cellfraction
from CellDART.utils import random_mix

from src.utils.data_loading import load_spatial, load_sc, get_selected_dir

from tqdm.autonotebook import tqdm


  from tqdm.autonotebook import tqdm


In [2]:
# Sanity check: list the physical devices TensorFlow has registered.
# The bare trailing expression makes the notebook display the returned list
# as the cell output.
tf.config.list_physical_devices()


[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
 PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [3]:
# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------

# Input data location and gene/scaler selection (passed to get_selected_dir
# and to the load_spatial / load_sc loaders).
DATA_DIR = "../AGrEDA/data"
# DATA_DIR = "./data_combine/"
N_MARKERS = 20
ALL_GENES = False
SCALER_NAME = "celldart"

# Spatial (ST) data handling.
TRAIN_USING_ALL_ST_SAMPLES = False
ST_SPLIT = False
SAMPLE_ID_N = "151673"

# Pseudo-spot mixtures built from the scRNA-seq data (n_spots / n_mix of load_sc).
N_SPOTS = 20000
N_MIX = 8

# Training hyper-parameters forwarded to da_cellfraction.train.
INITIAL_TRAIN_EPOCHS = 10
BATCH_SIZE = 512
ALPHA = 0.6
ALPHA_LR = 5
N_ITER = 3000

# Bootstrap mode: repeated trainings per alpha / slide.
BOOTSTRAP = False
BOOTSTRAP_ROUNDS = 10
BOOTSTRAP_ALPHAS = [0.6, 1 / 0.6]

# Model identification; forms the output path model/<MODEL_NAME>/<MODEL_VERSION>.
MODEL_NAME = "CellDART"
MODEL_VERSION = "V1"


In [4]:
# Root output directory for this model: model/<MODEL_NAME>/<MODEL_VERSION>.
model_folder = os.path.join("model", MODEL_NAME, MODEL_VERSION)

if os.path.isdir(model_folder):
    # Left over from an earlier run: reuse it without printing anything.
    pass
else:
    # First run: create the whole path (parents included) and echo it.
    os.makedirs(model_folder)
    print(model_folder)


## 1. Data load
### load scanpy data - 10x datasets

In [5]:
# Load spatial data
mat_sp_d, mat_sp_train, st_sample_id_l = load_spatial(
    get_selected_dir(DATA_DIR, N_MARKERS, ALL_GENES),
    SCALER_NAME,
    train_using_all_st_samples=TRAIN_USING_ALL_ST_SAMPLES,
    st_split=ST_SPLIT,
)

# Load sc data
sc_mix_d, lab_mix_d, sc_sub_dict, sc_sub_dict2 = load_sc(
    get_selected_dir(DATA_DIR, N_MARKERS, ALL_GENES),
    SCALER_NAME,
    n_mix=N_MIX,
    n_spots=N_SPOTS,
)


In [6]:
# Per-stage output folders under the model folder: the training cell saves
# the adversarially trained models to advtrain/ and the no-DA (pretrained)
# models to pretrain/.
advtrain_folder = os.path.join(model_folder, "advtrain")
pretrain_folder = os.path.join(model_folder, "pretrain")

# exist_ok=True gives the same result as the former isdir() pre-check when
# the folder already exists, without a separate check-then-create step.
os.makedirs(advtrain_folder, exist_ok=True)
os.makedirs(pretrain_folder, exist_ok=True)


In [7]:
# st_sample_id_l = [SAMPLE_ID_N]


In [8]:
# Adversarial domain-adaptation training with CellDART.
#
# The branch that runs is selected by the flags in the configuration cell:
#   * TRAIN_USING_ALL_ST_SAMPLES -> a single model trained against the ST
#                                   training matrix returned by load_spatial.
#   * BOOTSTRAP                  -> for every alpha in BOOTSTRAP_ALPHAS and
#                                   every ST slide, BOOTSTRAP_ROUNDS models
#                                   (seed = round index); their spot
#                                   predictions are collected in
#                                   pred_sp_boostrap_d[alpha][sample_id].
#   * otherwise                  -> one model per ST slide; the no-DA and
#                                   adversarially trained classifiers and
#                                   embedders are saved per slide.
#
# da_cellfraction.train returns four objects
# (embs, embs_noda, clssmodel, clssmodel_noda), as unpacked in the per-slide
# branch, so every branch unpacks four values. Per-slide ST matrices are
# indexed as mat_sp_d[sample_id]["train"] throughout.
if TRAIN_USING_ALL_ST_SAMPLES:
    print("Adversarial training for all ST slides")
    embs, embs_noda, clssmodel, clssmodel_noda = da_cellfraction.train(
        sc_mix_d["train"],
        lab_mix_d["train"],
        mat_sp_train,  # ST training matrix returned by load_spatial
        alpha=ALPHA,
        alpha_lr=ALPHA_LR,
        emb_dim=64,
        batch_size=BATCH_SIZE,
        n_iterations=N_ITER,
        initial_train=True,
        initial_train_epochs=INITIAL_TRAIN_EPOCHS,
    )
elif BOOTSTRAP:
    # pred_sp_boostrap_d[alpha][sample_id] -> list of per-round spot predictions.
    pred_sp_boostrap_d = {}

    outer = tqdm(total=len(BOOTSTRAP_ALPHAS), desc="Alphas", position=0)
    inner1 = tqdm(total=len(st_sample_id_l), desc="Sample", position=1)
    inner2 = tqdm(total=BOOTSTRAP_ROUNDS, desc="Bootstrap #", position=2)
    try:
        for alpha in BOOTSTRAP_ALPHAS:
            inner1.refresh()  # force print final state
            inner1.reset()  # reuse bar

            pred_sp_boostrap_d[alpha] = {}

            for sample_id in st_sample_id_l:
                inner2.refresh()  # force print final state
                inner2.reset()  # reuse bar

                sp_train = mat_sp_d[sample_id]["train"]
                pred_sp_boostrap_d[alpha][sample_id] = []
                for i in range(BOOTSTRAP_ROUNDS):
                    print(f"Adversarial training for ST slide {sample_id}: ")
                    embs, _, clssmodel, _ = da_cellfraction.train(
                        sc_mix_d["train"],
                        lab_mix_d["train"],
                        sp_train,
                        alpha=alpha,
                        alpha_lr=ALPHA_LR,
                        emb_dim=64,
                        batch_size=BATCH_SIZE,
                        n_iterations=N_ITER,
                        initial_train=True,
                        initial_train_epochs=INITIAL_TRAIN_EPOCHS,
                        seed=i,  # a different seed for each bootstrap round
                    )

                    pred_sp_boostrap_d[alpha][sample_id].append(
                        clssmodel.predict(sp_train)
                    )
                    inner2.update(1)
                inner1.update(1)
            outer.update(1)
    finally:
        # Release the progress bars even if a training run raises.
        inner2.close()
        inner1.close()
        outer.close()

else:
    # embs_d, clssmodel_d, clssmodel_noda_d = {}, {}, {}
    for sample_id in st_sample_id_l:
        print(f"Adversarial training for ST slide {sample_id}: ")
        # exist_ok=True: re-running the cell reuses existing per-slide folders.
        os.makedirs(os.path.join(advtrain_folder, sample_id), exist_ok=True)
        os.makedirs(os.path.join(pretrain_folder, sample_id), exist_ok=True)
        embs, embs_noda, clssmodel, clssmodel_noda = da_cellfraction.train(
            sc_mix_d["train"],
            lab_mix_d["train"],
            mat_sp_d[sample_id]["train"],
            alpha=ALPHA,
            alpha_lr=ALPHA_LR,
            emb_dim=64,
            batch_size=BATCH_SIZE,
            n_iterations=N_ITER,
            initial_train=True,
            initial_train_epochs=INITIAL_TRAIN_EPOCHS,
            seed=int(sample_id),  # slide IDs are numeric strings, e.g. "151673"
        )
        # embs_d[sample_id] = embs
        # clssmodel_d[sample_id] = clssmodel
        # clssmodel_noda_d[sample_id] = clssmodel_noda

        # print(keras.losses.KLDivergence()(lab_mix_d["train"], clssmodel_noda.predict(sc_mix_d["train"])).numpy())
        # print(keras.losses.KLDivergence()(lab_mix_d["train"], clssmodel.predict(sc_mix_d["train"])).numpy())

        # Save model: no-DA classifier/embedder under pretrain/<sample_id>/,
        # adversarially trained ones under advtrain/<sample_id>/.
        clssmodel_noda.save(os.path.join(pretrain_folder, sample_id, "final_model"))
        clssmodel.save(os.path.join(advtrain_folder, sample_id, "final_model"))

        embs_noda.save(os.path.join(pretrain_folder, sample_id, "embs"))
        embs.save(os.path.join(advtrain_folder, sample_id, "embs"))


Adversarial training for ST slide 151507: 
Train on 20000 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
initial_train_done


  updates = self.state_updates


0.609755919265747
0.609755919265747
Iteration 99, source loss =  3.951, discriminator acc = 0.180
Iteration 199, source loss =  1.886, discriminator acc = 0.457
Iteration 299, source loss =  1.424, discriminator acc = 0.928
Iteration 399, source loss =  1.558, discriminator acc = 0.845
Iteration 499, source loss =  1.367, discriminator acc = 0.988
Iteration 599, source loss =  1.208, discriminator acc = 0.155
Iteration 699, source loss =  1.221, discriminator acc = 0.997
Iteration 799, source loss =  1.072, discriminator acc = 0.711
Iteration 899, source loss =  1.043, discriminator acc = 0.918
Iteration 999, source loss =  1.088, discriminator acc = 1.000
Iteration 1099, source loss =  1.077, discriminator acc = 0.143
Iteration 1199, source loss =  0.838, discriminator acc = 1.000
Iteration 1299, source loss =  0.886, discriminator acc = 0.153
Iteration 1399, source loss =  0.834, discriminator acc = 0.656
Iteration 1499, source loss =  0.757, discriminator acc = 0.993
Iteration 1599,

  updates = self.state_updates


0.6339941244125367
0.6339941244125367
Iteration 99, source loss =  4.053, discriminator acc = 0.067
Iteration 199, source loss =  2.089, discriminator acc = 0.856
Iteration 299, source loss =  2.219, discriminator acc = 0.821
Iteration 399, source loss =  1.455, discriminator acc = 0.856
Iteration 499, source loss =  1.599, discriminator acc = 0.843
Iteration 599, source loss =  1.428, discriminator acc = 0.835
Iteration 699, source loss =  1.335, discriminator acc = 0.836
Iteration 799, source loss =  1.223, discriminator acc = 0.879
Iteration 899, source loss =  1.190, discriminator acc = 0.954
Iteration 999, source loss =  1.416, discriminator acc = 0.973
Iteration 1099, source loss =  1.286, discriminator acc = 0.942
Iteration 1199, source loss =  1.158, discriminator acc = 0.908
Iteration 1299, source loss =  1.065, discriminator acc = 0.437
Iteration 1399, source loss =  1.039, discriminator acc = 0.764
Iteration 1499, source loss =  0.913, discriminator acc = 0.866
Iteration 159

  updates = self.state_updates


0.6195967520713807
0.6195967520713807
Iteration 99, source loss =  3.593, discriminator acc = 0.056
Iteration 199, source loss =  1.802, discriminator acc = 0.818
Iteration 299, source loss =  1.440, discriminator acc = 0.812
Iteration 399, source loss =  1.287, discriminator acc = 0.482
Iteration 499, source loss =  1.646, discriminator acc = 0.187
Iteration 599, source loss =  1.290, discriminator acc = 0.193
Iteration 699, source loss =  1.297, discriminator acc = 0.813
Iteration 799, source loss =  1.307, discriminator acc = 0.193
Iteration 899, source loss =  1.266, discriminator acc = 0.922
Iteration 999, source loss =  1.172, discriminator acc = 0.475
Iteration 1099, source loss =  1.176, discriminator acc = 0.823
Iteration 1199, source loss =  1.159, discriminator acc = 0.614
Iteration 1299, source loss =  1.095, discriminator acc = 1.000
Iteration 1399, source loss =  1.095, discriminator acc = 0.816
Iteration 1499, source loss =  0.993, discriminator acc = 0.992
Iteration 159

  updates = self.state_updates


0.663599541759491
0.663599541759491
Iteration 99, source loss =  3.144, discriminator acc = 0.698
Iteration 199, source loss =  1.824, discriminator acc = 0.876
Iteration 299, source loss =  1.580, discriminator acc = 0.824
Iteration 399, source loss =  1.783, discriminator acc = 0.281
Iteration 499, source loss =  1.418, discriminator acc = 0.992
Iteration 599, source loss =  1.305, discriminator acc = 0.871
Iteration 699, source loss =  1.148, discriminator acc = 0.789
Iteration 799, source loss =  1.150, discriminator acc = 0.959
Iteration 899, source loss =  0.949, discriminator acc = 0.890
Iteration 999, source loss =  0.884, discriminator acc = 0.266
Iteration 1099, source loss =  0.803, discriminator acc = 0.328
Iteration 1199, source loss =  0.766, discriminator acc = 0.692
Iteration 1299, source loss =  0.728, discriminator acc = 0.593
Iteration 1399, source loss =  0.743, discriminator acc = 0.892
Iteration 1499, source loss =  0.778, discriminator acc = 0.691
Iteration 1599,

  updates = self.state_updates


0.6345004971504211
0.6345004971504211
Iteration 99, source loss =  3.426, discriminator acc = 0.155
Iteration 199, source loss =  2.907, discriminator acc = 0.714
Iteration 299, source loss =  2.119, discriminator acc = 0.993
Iteration 399, source loss =  1.660, discriminator acc = 0.208
Iteration 499, source loss =  1.162, discriminator acc = 0.879
Iteration 599, source loss =  1.209, discriminator acc = 0.979
Iteration 699, source loss =  1.436, discriminator acc = 0.182
Iteration 799, source loss =  1.229, discriminator acc = 0.939
Iteration 899, source loss =  1.207, discriminator acc = 0.931
Iteration 999, source loss =  1.056, discriminator acc = 0.618
Iteration 1099, source loss =  0.920, discriminator acc = 0.977
Iteration 1199, source loss =  0.984, discriminator acc = 0.185
Iteration 1299, source loss =  0.840, discriminator acc = 0.894
Iteration 1399, source loss =  0.768, discriminator acc = 0.953
Iteration 1499, source loss =  0.783, discriminator acc = 0.205
Iteration 159

  updates = self.state_updates


0.6091904026985169
0.6091904026985169
Iteration 99, source loss =  3.523, discriminator acc = 0.226
Iteration 199, source loss =  2.605, discriminator acc = 0.594
Iteration 299, source loss =  1.481, discriminator acc = 0.145
Iteration 399, source loss =  1.351, discriminator acc = 0.156
Iteration 499, source loss =  1.493, discriminator acc = 0.149
Iteration 599, source loss =  1.128, discriminator acc = 0.430
Iteration 699, source loss =  1.302, discriminator acc = 0.505
Iteration 799, source loss =  1.295, discriminator acc = 0.148
Iteration 899, source loss =  1.190, discriminator acc = 0.247
Iteration 999, source loss =  1.125, discriminator acc = 0.169
Iteration 1099, source loss =  1.193, discriminator acc = 0.149
Iteration 1199, source loss =  1.061, discriminator acc = 0.963
Iteration 1299, source loss =  1.159, discriminator acc = 0.649
Iteration 1399, source loss =  1.167, discriminator acc = 0.898
Iteration 1499, source loss =  0.957, discriminator acc = 0.377
Iteration 159

  updates = self.state_updates


0.5910753908157349
0.5910753908157349
Iteration 99, source loss =  2.810, discriminator acc = 0.170
Iteration 199, source loss =  2.593, discriminator acc = 0.170
Iteration 299, source loss =  1.459, discriminator acc = 0.868
Iteration 399, source loss =  1.477, discriminator acc = 0.999
Iteration 499, source loss =  1.321, discriminator acc = 0.170
Iteration 599, source loss =  1.479, discriminator acc = 0.430
Iteration 699, source loss =  1.358, discriminator acc = 0.166
Iteration 799, source loss =  1.162, discriminator acc = 0.946
Iteration 899, source loss =  1.189, discriminator acc = 0.191
Iteration 999, source loss =  1.138, discriminator acc = 1.000
Iteration 1099, source loss =  1.057, discriminator acc = 0.632
Iteration 1199, source loss =  1.010, discriminator acc = 1.000
Iteration 1299, source loss =  1.020, discriminator acc = 0.469
Iteration 1399, source loss =  0.921, discriminator acc = 0.551
Iteration 1499, source loss =  1.045, discriminator acc = 0.898
Iteration 159

  updates = self.state_updates


0.5862456747531891
0.5862456747531891
Iteration 99, source loss =  2.423, discriminator acc = 0.072
Iteration 199, source loss =  1.445, discriminator acc = 0.151
Iteration 299, source loss =  1.468, discriminator acc = 0.938
Iteration 399, source loss =  1.513, discriminator acc = 0.497
Iteration 499, source loss =  1.373, discriminator acc = 0.293
Iteration 599, source loss =  1.457, discriminator acc = 0.385
Iteration 699, source loss =  1.257, discriminator acc = 0.979
Iteration 799, source loss =  1.091, discriminator acc = 0.829
Iteration 899, source loss =  1.090, discriminator acc = 0.860
Iteration 999, source loss =  1.225, discriminator acc = 0.435
Iteration 1099, source loss =  1.083, discriminator acc = 0.965
Iteration 1199, source loss =  1.100, discriminator acc = 0.788
Iteration 1299, source loss =  1.094, discriminator acc = 0.984
Iteration 1399, source loss =  1.074, discriminator acc = 0.166
Iteration 1499, source loss =  1.039, discriminator acc = 0.998
Iteration 159

  updates = self.state_updates


0.6047526482105255
0.6047526482105255
Iteration 99, source loss =  2.733, discriminator acc = 0.859
Iteration 199, source loss =  2.062, discriminator acc = 0.154
Iteration 299, source loss =  1.557, discriminator acc = 0.623
Iteration 399, source loss =  1.483, discriminator acc = 0.217
Iteration 499, source loss =  1.316, discriminator acc = 0.901
Iteration 599, source loss =  1.292, discriminator acc = 0.886
Iteration 699, source loss =  1.124, discriminator acc = 0.151
Iteration 799, source loss =  1.150, discriminator acc = 0.971
Iteration 899, source loss =  1.114, discriminator acc = 0.688
Iteration 999, source loss =  1.175, discriminator acc = 0.969
Iteration 1099, source loss =  1.258, discriminator acc = 0.158
Iteration 1199, source loss =  1.054, discriminator acc = 0.961
Iteration 1299, source loss =  1.001, discriminator acc = 0.381
Iteration 1399, source loss =  1.040, discriminator acc = 0.703
Iteration 1499, source loss =  0.956, discriminator acc = 0.795
Iteration 159

  updates = self.state_updates


0.6114996797084808
0.6114996797084808
Iteration 99, source loss =  4.454, discriminator acc = 0.155
Iteration 199, source loss =  2.480, discriminator acc = 0.155
Iteration 299, source loss =  1.531, discriminator acc = 0.856
Iteration 399, source loss =  1.431, discriminator acc = 0.950
Iteration 499, source loss =  1.413, discriminator acc = 0.876
Iteration 599, source loss =  1.577, discriminator acc = 0.156
Iteration 699, source loss =  1.419, discriminator acc = 0.845
Iteration 799, source loss =  1.152, discriminator acc = 0.809
Iteration 899, source loss =  1.275, discriminator acc = 0.963
Iteration 999, source loss =  1.035, discriminator acc = 0.652
Iteration 1099, source loss =  1.034, discriminator acc = 0.810
Iteration 1199, source loss =  1.031, discriminator acc = 0.338
Iteration 1299, source loss =  0.994, discriminator acc = 0.538
Iteration 1399, source loss =  0.819, discriminator acc = 0.967
Iteration 1499, source loss =  0.788, discriminator acc = 0.959
Iteration 159

  updates = self.state_updates


0.6360885352134704
0.6360885352134704
Iteration 99, source loss =  2.281, discriminator acc = 0.137
Iteration 199, source loss =  2.358, discriminator acc = 0.762
Iteration 299, source loss =  1.646, discriminator acc = 0.880
Iteration 399, source loss =  1.668, discriminator acc = 0.877
Iteration 499, source loss =  1.350, discriminator acc = 0.874
Iteration 599, source loss =  1.552, discriminator acc = 0.888
Iteration 699, source loss =  1.322, discriminator acc = 0.901
Iteration 799, source loss =  1.288, discriminator acc = 0.882
Iteration 899, source loss =  1.250, discriminator acc = 0.873
Iteration 999, source loss =  1.287, discriminator acc = 0.882
Iteration 1099, source loss =  1.243, discriminator acc = 0.995
Iteration 1199, source loss =  1.294, discriminator acc = 0.951
Iteration 1299, source loss =  1.455, discriminator acc = 0.498
Iteration 1399, source loss =  1.342, discriminator acc = 0.883
Iteration 1499, source loss =  1.047, discriminator acc = 0.859
Iteration 159

  updates = self.state_updates


0.6267935317993164
0.6267935317993164
Iteration 99, source loss =  2.441, discriminator acc = 0.847
Iteration 199, source loss =  1.537, discriminator acc = 0.865
Iteration 299, source loss =  1.552, discriminator acc = 0.855
Iteration 399, source loss =  1.503, discriminator acc = 0.871
Iteration 499, source loss =  1.373, discriminator acc = 0.883
Iteration 599, source loss =  1.590, discriminator acc = 0.873
Iteration 699, source loss =  1.297, discriminator acc = 0.885
Iteration 799, source loss =  1.325, discriminator acc = 0.875
Iteration 899, source loss =  1.265, discriminator acc = 0.878
Iteration 999, source loss =  1.179, discriminator acc = 0.887
Iteration 1099, source loss =  1.233, discriminator acc = 0.973
Iteration 1199, source loss =  1.316, discriminator acc = 0.530
Iteration 1299, source loss =  1.380, discriminator acc = 0.987
Iteration 1399, source loss =  1.063, discriminator acc = 0.959
Iteration 1499, source loss =  1.107, discriminator acc = 0.374
Iteration 159

In [9]:
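# NOTE: despite its name, `jsd` below returns the Kullback-Leibler divergence, not the Jensen-Shannon divergence.
# def jsd(y_true, y_pred):
#     return tf.keras.losses.kullback_leibler_divergence(y_true, y_pred)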


In [10]:
# clssmodel.compile(
#         optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
#         loss={'mo': 'kld'},
#         metrics=['mae'],
#     )

# pred = clssmodel.evaluate(sc_mix_d["train"], lab_mix_d["train"])


In [11]:
# pred


In [12]:
# clssmodel_noda.compile(
#         optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
#         loss={'mo': 'kld'},
#         metrics=['mae'],
#     )

# clssmodel_noda.evaluate(sc_mix_d["train"], lab_mix_d["train"])


In [13]:
# confidence = 0.95
# if BOOTSTRAP:
#     for alpha in BOOTSTRAP_ALPHAS:
#         bootstrap_d = {}
#         for sample_id in st_sample_id_l:


#             bootstrap_d[sample_id] = {}
#             for i, num in enumerate(numlist):

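#                 # NOTE(review): inside the `for sample_id` loop this presumably should use `sample_id` instead of SAMPLE_ID_N; as written, every sample gets SAMPLE_ID_N's scores.
#                 acc = [metrics.roc_auc_score(*gen_pred_true(num, adata_spatialLIBD_d[SAMPLE_ID_N], pred)[::-1]) for pred in pred_sp_boostrap_d[alpha][SAMPLE_ID_N]]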
#                 test_mean = np.mean(acc)
#                 t_value = ss.t.ppf((1 + confidence) / 2.0, df=BOOTSTRAP_ROUNDS - 1)

#                 sd = np.std(acc, ddof=1)
#                 se = sd / np.sqrt(BOOTSTRAP_ROUNDS)

#                 ci_length = t_value * se

#                 ci_lower = test_mean - ci_length
#                 ci_upper = test_mean + ci_length

#                 bootstrap_d[sample_id][num_to_ex_d[num]] = (ci_lower, test_mean, ci_upper)


#         bootstrap_df = pd.DataFrame.from_dict(bootstrap_d)
#         display(bootstrap_df)
#         bootstrap_df.to_csv(os.path.join(results_folder, f'bootstrap_alpha{alpha}.csv'))


## 4. Predict cell fractions of spots and visualize them

In [14]:
# pred_sp_d, pred_sp_noda_d = {}, {}
# if TRAIN_USING_ALL_ST_SAMPLES:
#     for sample_id in st_sample_id_l:
#         pred_sp_d[sample_id] = clssmodel.predict(mat_sp_test_s_d[sample_id])
#         pred_sp_noda_d[sample_id] = clssmodel_noda.predict(mat_sp_test_s_d[sample_id])
# else:
#     for sample_id in st_sample_id_l:
#         pred_sp_d[sample_id] = clssmodel_d[sample_id].predict(
#             mat_sp_test_s_d[sample_id]
#         )
#         pred_sp_noda_d[sample_id] = clssmodel_noda_d[sample_id].predict(
#             mat_sp_test_s_d[sample_id]
#         )


In [15]:
# def plot_cellfraction(visnum, adata, pred_sp, ax=None):
#     """Plot predicted cell fraction for a given visnum"""
#     adata.obs["Pred_label"] = pred_sp[:, visnum]
#     # vmin = 0
#     # vmax = np.amax(pred_sp)

#     sc.pl.spatial(
#         adata,
#         img_key="hires",
#         color="Pred_label",
#         palette="Set1",
#         size=1.5,
#         legend_loc=None,
#         title=f"{sc_sub_dict[visnum]}",
#         spot_size=100,
#         show=False,
#         # vmin=vmin,
#         # vmax=vmax,
#         ax=ax,
#     )


In [16]:
# def plot_cell_layers(df):

#     layer_idx = df["spatialLIBD"].unique().astype(str)
#     samples = df["sample_id"].unique()
#     layer_idx.sort()
#     fig, ax = plt.subplots(
#         nrows=1,
#         ncols=len(samples),
#         figsize=(5 * len(samples), 5),
#         squeeze=False,
#         constrained_layout=True,
#     )

#     for idx, sample in enumerate(samples):
#         cells_of_samples = df[df["sample_id"] == sample]
#         for index in layer_idx:
#             cells_of_layer = cells_of_samples[cells_of_samples["spatialLIBD"] == index]
#             ax.flat[idx].scatter(
#                 cells_of_layer["X"], -cells_of_layer["Y"], label=index, s=17, marker="o"
#             )

#         ax.flat[idx].axis("equal")
#         ax.flat[idx].set_xticks([])
#         ax.flat[idx].set_yticks([])
#         ax.flat[idx].set_title(sample)

#     plt.legend()
#     plt.show()


In [17]:
# def plot_roc(visnum, adata, pred_sp, name, ax=None):
#     """Plot ROC for a given visnum"""

#     def layer_to_layer_number(x):
#         for char in x:
#             if char.isdigit():
#                 if int(char) in Ex_to_L_d[num_to_ex_d[visnum]]:
#                     return 1
#         return 0

#     y_pred = pred_sp[:, visnum]
#     y_true = adata.obs["spatialLIBD"].map(layer_to_layer_number).fillna(0)
#     # print(y_true)
#     # print(y_true.isna().sum())
#     RocCurveDisplay.from_predictions(y_true=y_true, y_pred=y_pred, name=name, ax=ax)


In [18]:
# # plot_cell_layers(adata_spatialLIBD_151673.obs)

# fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(5, 5), constrained_layout=True)

# sc.pl.spatial(
#     adata_spatialLIBD_d[SAMPLE_ID_N],
#     img_key=None,
#     color="spatialLIBD",
#     palette="Accent_r",
#     size=1.5,
#     title=SAMPLE_ID_N,
#     # legend_loc = 4,
#     spot_size=100,
#     show=False,
#     ax=ax,
# )

# ax.axis("equal")
# ax.set_xlabel("")
# ax.set_ylabel("")

# fig.show()


In [19]:
# fig, ax = plt.subplots(2, 5, figsize=(20, 8), constrained_layout=True)

# for i, num in enumerate(numlist):
#     plot_cellfraction(
#         num, adata_spatialLIBD_d[SAMPLE_ID_N], pred_sp_d[SAMPLE_ID_N], ax.flat[i]
#     )
#     ax.flat[i].axis("equal")
#     ax.flat[i].set_xlabel("")
#     ax.flat[i].set_ylabel("")

# fig.show()
# # plt.close()

# fig, ax = plt.subplots(
#     2, 5, figsize=(20, 8), constrained_layout=True, sharex=True, sharey=True
# )

# for i, num in enumerate(numlist):
#     plot_roc(
#         num,
#         adata_spatialLIBD_d[SAMPLE_ID_N],
#         pred_sp_d[SAMPLE_ID_N],
#         "CellDART",
#         ax.flat[i],
#     )
#     plot_roc(
#         num,
#         adata_spatialLIBD_d[SAMPLE_ID_N],
#         pred_sp_noda_d[SAMPLE_ID_N],
#         "NN_wo_da",
#         ax.flat[i],
#     )
#     ax.flat[i].plot([0, 1], [0, 1], transform=ax.flat[i].transAxes, ls="--", color="k")
#     ax.flat[i].set_aspect("equal")
#     ax.flat[i].set_xlim([0, 1])
#     ax.flat[i].set_ylim([0, 1])

#     ax.flat[i].set_title(f"{sc_sub_dict[num]}")

#     if i >= len(numlist) - 5:
#         ax.flat[i].set_xlabel("FPR")
#     else:
#         ax.flat[i].set_xlabel("")
#     if i % 5 == 0:
#         ax.flat[i].set_ylabel("TPR")
#     else:
#         ax.flat[i].set_ylabel("")

# fig.show()
# # plt.close()


- cf. Prediction on mixtures (pseudospots)


In [20]:
# if TRAIN_USING_ALL_ST_SAMPLES:
#     pred_mix = clssmodel.predict(sc_mix_test_s)
# else:
#     pred_mix = clssmodel_d[SAMPLE_ID_N].predict(sc_mix_test_s)


# cell_type_nums = sc_sub_dict.keys()
# nrows = ceil(len(cell_type_nums) / 5)

# line_kws = {"color": "tab:orange"}
# scatter_kws = {"s": 5}

# props = dict(facecolor="w", alpha=0.5)

# fig, ax = plt.subplots(
#     nrows,
#     5,
#     figsize=(20, 4 * nrows),
#     constrained_layout=True,
#     sharex=False,
#     sharey=True,
# )
# for i, num in enumerate(cell_type_nums):
#     sns.regplot(
#         x=pred_mix[:, num],
#         y=lab_mix_test[:, num],
#         line_kws=line_kws,
#         scatter_kws=scatter_kws,
#         ax=ax.flat[i],
#     ).set_title(sc_sub_dict[num])
#     ax.flat[i].set_aspect("equal")

#     ax.flat[i].set_xlabel("Predicted Proportion")
#     if i % 5 == 0:
#         ax.flat[i].set_ylabel("True Proportion")
#     else:
#         ax.flat[i].set_ylabel("")
#     ax.flat[i].set_xlim([0, 1])
#     ax.flat[i].set_ylim([0, 1])

#     textstr = f"MSE: {mean_squared_error(pred_mix[:,num], lab_mix_test[:,num]):.5f}"

#     # place a text box in upper left in axes coords
#     ax.flat[i].text(
#         0.95,
#         0.05,
#         textstr,
#         transform=ax.flat[i].transAxes,
#         verticalalignment="bottom",
#         horizontalalignment="right",
#         bbox=props,
#     )

# for i in range(len(cell_type_nums), nrows * 5):
#     ax.flat[i].axis("off")

# plt.show()


In [21]:
# print(
#     "\n".join(
#         f"{m.__name__} {m.__version__}"
#         for m in globals().values()
#         if getattr(m, "__version__", None)
#     )
# )
