# CellDART Example Code: mouse brain 
## (10x Visium of anterior mouse brain + scRNA-seq data of mouse brain)

In [1]:
import glob
import os
import pickle
from math import ceil

import anndata as ad
import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import scanpy as sc
import seaborn as sns
import scipy.stats as ss

from sklearn.model_selection import train_test_split
from sklearn.metrics import RocCurveDisplay
from sklearn.metrics import mean_squared_error
from sklearn import metrics

from CellDART import da_cellfraction
from CellDART.utils import random_mix

from src.utils.data_loading import load_spatial, load_sc, get_selected_dir

from tqdm.autonotebook import tqdm


  from tqdm.autonotebook import tqdm


In [2]:
# Import TensorFlow before querying devices: the import itself is what
# registers PluggableDevices.
import tensorflow as tf

# Last expression of the cell, so the notebook renders the device list.
tf.config.list_physical_devices()

[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
 PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [3]:
# --- Run identity -----------------------------------------------------------
# Joined into the checkpoint directory: model/<MODEL_NAME>/<MODEL_VERSION>.
MODEL_NAME = "CellDART"
MODEL_VERSION = "V1"

# --- Data selection ---------------------------------------------------------
# DATA_DIR, N_MARKERS and ALL_GENES are handed to get_selected_dir() to pick
# the preprocessed data directory.
DATA_DIR = "../AGrEDA/data"
# DATA_DIR = "./data_combine/"
N_MARKERS = 20
ALL_GENES = False

# Passed to both load_spatial() and load_sc().
SCALER_NAME = "celldart"

# load_spatial() options; TRAIN_USING_ALL_ST_SAMPLES also selects the
# "all slides" branch of the training cell.
TRAIN_USING_ALL_ST_SAMPLES = False
ST_SPLIT = False

# load_sc() options (n_spots / n_mix).
N_SPOTS = 20_000
N_MIX = 8

# Slide id used by the (currently commented-out) plotting cells.
SAMPLE_ID_N = "151673"

# --- Training ---------------------------------------------------------------
# Keyword arguments for da_cellfraction.train().
INITIAL_TRAIN_EPOCHS = 10
BATCH_SIZE = 512
N_ITER = 3_000
ALPHA = 0.6
# Presumably the value for train()'s alpha_lr argument -- TODO confirm.
ALPHA_LR = 5

# --- Bootstrap mode ---------------------------------------------------------
# When BOOTSTRAP is True (and TRAIN_USING_ALL_ST_SAMPLES is False), every
# slide is trained BOOTSTRAP_ROUNDS times for each alpha in BOOTSTRAP_ALPHAS.
BOOTSTRAP = False
BOOTSTRAP_ROUNDS = 10
BOOTSTRAP_ALPHAS = [0.6, 1 / 0.6]

In [4]:
# Checkpoint root for this run: model/<MODEL_NAME>/<MODEL_VERSION>.
model_folder = os.path.join("model", MODEL_NAME, MODEL_VERSION)

# Create the directory on first use and echo its path; when it already
# exists the cell produces no output.
model_folder_missing = not os.path.isdir(model_folder)
if model_folder_missing:
    os.makedirs(model_folder)
    print(model_folder)


## 1. Data load
### Load scanpy data - 10x datasets

In [5]:
# Resolve the preprocessed-data directory once so the spatial and
# single-cell loaders are guaranteed to read from the same location.
# Assumes get_selected_dir() is a pure path lookup -- TODO confirm.
selected_dir = get_selected_dir(DATA_DIR, N_MARKERS, ALL_GENES)

# Load spatial data.
# mat_sp_d is indexed below as mat_sp_d[sample_id]["train"];
# st_sample_id_l is the list of slide ids iterated during training.
mat_sp_d, mat_sp_train, st_sample_id_l = load_spatial(
    selected_dir,
    SCALER_NAME,
    train_using_all_st_samples=TRAIN_USING_ALL_ST_SAMPLES,
    st_split=ST_SPLIT,
)

# Load sc data.
# sc_mix_d / lab_mix_d are indexed by split name (["train"] is used for
# training); sc_sub_dict and sc_sub_dict2 are not used by any active cell
# in the visible notebook.
sc_mix_d, lab_mix_d, sc_sub_dict, sc_sub_dict2 = load_sc(
    selected_dir,
    SCALER_NAME,
    n_mix=N_MIX,
    n_spots=N_SPOTS,
)


In [6]:
# Separate checkpoint roots for the adversarially trained models and the
# models saved alongside them under "pretrain".
advtrain_folder = os.path.join(model_folder, "advtrain")
pretrain_folder = os.path.join(model_folder, "pretrain")

# Create both directories if they are not there yet.
for stage_folder in (advtrain_folder, pretrain_folder):
    os.makedirs(stage_folder, exist_ok=True)

In [7]:
def _train_celldart(mat_sp_target, alpha=ALPHA, **train_kwargs):
    """Call ``da_cellfraction.train`` with the notebook's shared settings.

    Keeps the training hyper-parameters in one place so the three training
    modes below cannot drift apart.

    Args:
        mat_sp_target: Spatial data passed as the third positional argument
            of ``da_cellfraction.train`` (after ``sc_mix_d["train"]`` and
            ``lab_mix_d["train"]``).
        alpha: Forwarded as ``alpha``; defaults to ``ALPHA``.
        **train_kwargs: Extra keyword arguments forwarded unchanged
            (used below to pass ``seed``).

    Returns:
        Whatever ``da_cellfraction.train`` returns; the callers below unpack
        it as ``(embs, embs_noda, clssmodel, clssmodel_noda)``.
    """
    return da_cellfraction.train(
        sc_mix_d["train"],
        lab_mix_d["train"],
        mat_sp_target,
        alpha=alpha,
        alpha_lr=ALPHA_LR,
        emb_dim=64,
        batch_size=BATCH_SIZE,
        n_iterations=N_ITER,
        initial_train=True,
        initial_train_epochs=INITIAL_TRAIN_EPOCHS,
        **train_kwargs,
    )


if TRAIN_USING_ALL_ST_SAMPLES:
    print("Adversarial training for all ST slides")
    # mat_sp_train is the combined training matrix returned by load_spatial().
    # NOTE(review): unlike the per-slide branch, this branch passes no seed
    # and does not save the resulting models -- confirm that is intended.
    embs, embs_noda, clssmodel, clssmodel_noda = _train_celldart(mat_sp_train)
elif BOOTSTRAP:
    # pred_sp_boostrap_d[alpha][sample_id] -> list of predictions, one per round.
    pred_sp_boostrap_d = {}

    outer = tqdm(total=len(BOOTSTRAP_ALPHAS), desc="Alphas", position=0)
    inner1 = tqdm(total=len(st_sample_id_l), desc="Sample", position=1)
    inner2 = tqdm(total=BOOTSTRAP_ROUNDS, desc="Bootstrap #", position=2)
    for alpha in BOOTSTRAP_ALPHAS:
        inner1.refresh()  # force print final state
        inner1.reset()  # reuse bar

        pred_sp_boostrap_d[alpha] = {}

        for sample_id in st_sample_id_l:
            inner2.refresh()  # force print final state
            inner2.reset()  # reuse bar

            # Same indexing order as the per-slide branch: slide id first,
            # then split name.
            mat_sp_target = mat_sp_d[sample_id]["train"]

            pred_sp_boostrap_d[alpha][sample_id] = []
            for i in range(BOOTSTRAP_ROUNDS):
                print(f"Adversarial training for ST slide {sample_id}: ")
                # The round index doubles as the seed so rounds differ from
                # each other but are repeatable across runs.
                embs, _, clssmodel, _ = _train_celldart(
                    mat_sp_target, alpha=alpha, seed=i
                )

                pred_sp_boostrap_d[alpha][sample_id].append(
                    clssmodel.predict(mat_sp_target)
                )
                inner2.update(1)
            inner1.update(1)
        outer.update(1)

    # Release the progress bars once the sweep is finished.
    for progress_bar in (inner2, inner1, outer):
        progress_bar.close()

else:
    # Models are written to disk per slide rather than kept in memory.
    for sample_id in st_sample_id_l:
        print(f"Adversarial training for ST slide {sample_id}: ")
        advtrain_sample_folder = os.path.join(advtrain_folder, sample_id)
        pretrain_sample_folder = os.path.join(pretrain_folder, sample_id)
        os.makedirs(advtrain_sample_folder, exist_ok=True)
        os.makedirs(pretrain_sample_folder, exist_ok=True)

        # The numeric slide id is used as the seed for this slide's run.
        embs, embs_noda, clssmodel, clssmodel_noda = _train_celldart(
            mat_sp_d[sample_id]["train"], seed=int(sample_id)
        )

        # Save model
        clssmodel_noda.save(os.path.join(pretrain_sample_folder, "final_model"))
        clssmodel.save(os.path.join(advtrain_sample_folder, "final_model"))

        embs_noda.save(os.path.join(pretrain_sample_folder, "embs"))
        embs.save(os.path.join(advtrain_sample_folder, "embs"))


Adversarial training for ST slide 151507: 
Train on 20000 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
initial_train_done


  updates = self.state_updates


Iteration 99, source loss =  3.950, discriminator acc = 0.180
Iteration 199, source loss =  1.888, discriminator acc = 0.454
Iteration 299, source loss =  1.431, discriminator acc = 0.927
Iteration 399, source loss =  1.555, discriminator acc = 0.878
Iteration 499, source loss =  1.389, discriminator acc = 0.992
Iteration 599, source loss =  1.242, discriminator acc = 0.153
Iteration 699, source loss =  1.110, discriminator acc = 1.000
Iteration 799, source loss =  1.088, discriminator acc = 0.632
Iteration 899, source loss =  1.090, discriminator acc = 0.996
Iteration 999, source loss =  1.192, discriminator acc = 0.124
Iteration 1099, source loss =  1.054, discriminator acc = 0.292
Iteration 1199, source loss =  0.921, discriminator acc = 0.905
Iteration 1299, source loss =  0.993, discriminator acc = 0.182
Iteration 1399, source loss =  0.926, discriminator acc = 0.999
Iteration 1499, source loss =  1.067, discriminator acc = 0.151
Iteration 1599, source loss =  0.919, discriminator

  updates = self.state_updates


Iteration 99, source loss =  4.053, discriminator acc = 0.066
Iteration 199, source loss =  2.089, discriminator acc = 0.855
Iteration 299, source loss =  2.219, discriminator acc = 0.821
Iteration 399, source loss =  1.457, discriminator acc = 0.856
Iteration 499, source loss =  1.599, discriminator acc = 0.843
Iteration 599, source loss =  1.429, discriminator acc = 0.835
Iteration 699, source loss =  1.336, discriminator acc = 0.836
Iteration 799, source loss =  1.223, discriminator acc = 0.880
Iteration 899, source loss =  1.189, discriminator acc = 0.957
Iteration 999, source loss =  1.432, discriminator acc = 0.966
Iteration 1099, source loss =  1.277, discriminator acc = 0.949
Iteration 1199, source loss =  1.145, discriminator acc = 0.921
Iteration 1299, source loss =  1.066, discriminator acc = 0.402
Iteration 1399, source loss =  1.041, discriminator acc = 0.710
Iteration 1499, source loss =  0.906, discriminator acc = 0.856
Iteration 1599, source loss =  0.861, discriminator

  updates = self.state_updates


Iteration 99, source loss =  3.593, discriminator acc = 0.056
Iteration 199, source loss =  1.802, discriminator acc = 0.818
Iteration 299, source loss =  1.440, discriminator acc = 0.812
Iteration 399, source loss =  1.289, discriminator acc = 0.484
Iteration 499, source loss =  1.650, discriminator acc = 0.187
Iteration 599, source loss =  1.273, discriminator acc = 0.194
Iteration 699, source loss =  1.317, discriminator acc = 0.779
Iteration 799, source loss =  1.301, discriminator acc = 0.194
Iteration 899, source loss =  1.265, discriminator acc = 0.954
Iteration 999, source loss =  1.175, discriminator acc = 0.448
Iteration 1099, source loss =  1.186, discriminator acc = 0.864
Iteration 1199, source loss =  1.151, discriminator acc = 0.632
Iteration 1299, source loss =  1.108, discriminator acc = 1.000
Iteration 1399, source loss =  1.094, discriminator acc = 0.814
Iteration 1499, source loss =  0.988, discriminator acc = 0.996
Iteration 1599, source loss =  1.090, discriminator

  updates = self.state_updates


Iteration 99, source loss =  3.144, discriminator acc = 0.698
Iteration 199, source loss =  1.824, discriminator acc = 0.876
Iteration 299, source loss =  1.579, discriminator acc = 0.824
Iteration 399, source loss =  1.782, discriminator acc = 0.281
Iteration 499, source loss =  1.427, discriminator acc = 0.992
Iteration 599, source loss =  1.300, discriminator acc = 0.871
Iteration 699, source loss =  1.148, discriminator acc = 0.792
Iteration 799, source loss =  1.151, discriminator acc = 0.959
Iteration 899, source loss =  0.954, discriminator acc = 0.896
Iteration 999, source loss =  0.886, discriminator acc = 0.275
Iteration 1099, source loss =  0.804, discriminator acc = 0.347
Iteration 1199, source loss =  0.774, discriminator acc = 0.731
Iteration 1299, source loss =  0.718, discriminator acc = 0.821
Iteration 1399, source loss =  0.783, discriminator acc = 0.907
Iteration 1499, source loss =  0.763, discriminator acc = 0.481
Iteration 1599, source loss =  0.695, discriminator

  updates = self.state_updates


Iteration 99, source loss =  3.426, discriminator acc = 0.155
Iteration 199, source loss =  2.906, discriminator acc = 0.714
Iteration 299, source loss =  2.119, discriminator acc = 0.993
Iteration 399, source loss =  1.659, discriminator acc = 0.208
Iteration 499, source loss =  1.161, discriminator acc = 0.879
Iteration 599, source loss =  1.208, discriminator acc = 0.979
Iteration 699, source loss =  1.434, discriminator acc = 0.180
Iteration 799, source loss =  1.227, discriminator acc = 0.939
Iteration 899, source loss =  1.208, discriminator acc = 0.931
Iteration 999, source loss =  1.054, discriminator acc = 0.622
Iteration 1099, source loss =  0.916, discriminator acc = 0.977
Iteration 1199, source loss =  0.988, discriminator acc = 0.179
Iteration 1299, source loss =  0.841, discriminator acc = 0.880
Iteration 1399, source loss =  0.766, discriminator acc = 0.964
Iteration 1499, source loss =  0.783, discriminator acc = 0.188
Iteration 1599, source loss =  0.734, discriminator

  updates = self.state_updates


Iteration 99, source loss =  3.523, discriminator acc = 0.226
Iteration 199, source loss =  2.603, discriminator acc = 0.594
Iteration 299, source loss =  1.481, discriminator acc = 0.145
Iteration 399, source loss =  1.385, discriminator acc = 0.159
Iteration 499, source loss =  1.486, discriminator acc = 0.150
Iteration 599, source loss =  1.119, discriminator acc = 0.544
Iteration 699, source loss =  1.314, discriminator acc = 0.143
Iteration 799, source loss =  1.257, discriminator acc = 0.149
Iteration 899, source loss =  1.158, discriminator acc = 0.313
Iteration 999, source loss =  1.193, discriminator acc = 0.829
Iteration 1099, source loss =  1.071, discriminator acc = 0.335
Iteration 1199, source loss =  1.135, discriminator acc = 0.416
Iteration 1299, source loss =  1.050, discriminator acc = 1.000
Iteration 1399, source loss =  1.056, discriminator acc = 0.279
Iteration 1499, source loss =  0.914, discriminator acc = 0.999
Iteration 1599, source loss =  0.902, discriminator

  updates = self.state_updates


Iteration 99, source loss =  2.810, discriminator acc = 0.170
Iteration 199, source loss =  2.593, discriminator acc = 0.170
Iteration 299, source loss =  1.463, discriminator acc = 0.869
Iteration 399, source loss =  1.476, discriminator acc = 0.999
Iteration 499, source loss =  1.320, discriminator acc = 0.170
Iteration 599, source loss =  1.474, discriminator acc = 0.429
Iteration 699, source loss =  1.366, discriminator acc = 0.166
Iteration 799, source loss =  1.143, discriminator acc = 0.945
Iteration 899, source loss =  1.196, discriminator acc = 0.193
Iteration 999, source loss =  1.160, discriminator acc = 1.000
Iteration 1099, source loss =  1.105, discriminator acc = 0.624
Iteration 1199, source loss =  1.047, discriminator acc = 1.000
Iteration 1299, source loss =  1.029, discriminator acc = 0.201
Iteration 1399, source loss =  0.952, discriminator acc = 0.939
Iteration 1499, source loss =  0.931, discriminator acc = 0.265
Iteration 1599, source loss =  0.877, discriminator

  updates = self.state_updates


Iteration 99, source loss =  2.421, discriminator acc = 0.074
Iteration 199, source loss =  1.426, discriminator acc = 0.151
Iteration 299, source loss =  1.454, discriminator acc = 0.932
Iteration 399, source loss =  1.516, discriminator acc = 0.399
Iteration 499, source loss =  1.360, discriminator acc = 0.313
Iteration 599, source loss =  1.259, discriminator acc = 0.180
Iteration 699, source loss =  1.182, discriminator acc = 0.474
Iteration 799, source loss =  1.122, discriminator acc = 1.000
Iteration 899, source loss =  1.258, discriminator acc = 0.469
Iteration 999, source loss =  1.224, discriminator acc = 0.213
Iteration 1099, source loss =  1.150, discriminator acc = 0.170
Iteration 1199, source loss =  1.135, discriminator acc = 0.915
Iteration 1299, source loss =  1.111, discriminator acc = 0.850
Iteration 1399, source loss =  1.108, discriminator acc = 0.228
Iteration 1499, source loss =  1.098, discriminator acc = 0.922
Iteration 1599, source loss =  1.105, discriminator

  updates = self.state_updates


Iteration 99, source loss =  2.733, discriminator acc = 0.859
Iteration 199, source loss =  2.061, discriminator acc = 0.154
Iteration 299, source loss =  1.557, discriminator acc = 0.623
Iteration 399, source loss =  1.483, discriminator acc = 0.218
Iteration 499, source loss =  1.317, discriminator acc = 0.900
Iteration 599, source loss =  1.290, discriminator acc = 0.893
Iteration 699, source loss =  1.129, discriminator acc = 0.151
Iteration 799, source loss =  1.154, discriminator acc = 0.970
Iteration 899, source loss =  1.110, discriminator acc = 0.715
Iteration 999, source loss =  1.185, discriminator acc = 0.969
Iteration 1099, source loss =  1.284, discriminator acc = 0.156
Iteration 1199, source loss =  1.040, discriminator acc = 0.952
Iteration 1299, source loss =  1.010, discriminator acc = 0.322
Iteration 1399, source loss =  1.037, discriminator acc = 0.722
Iteration 1499, source loss =  0.953, discriminator acc = 0.789
Iteration 1599, source loss =  1.153, discriminator

  updates = self.state_updates


Iteration 99, source loss =  4.454, discriminator acc = 0.155
Iteration 199, source loss =  2.480, discriminator acc = 0.155
Iteration 299, source loss =  1.532, discriminator acc = 0.856
Iteration 399, source loss =  1.431, discriminator acc = 0.949
Iteration 499, source loss =  1.423, discriminator acc = 0.876
Iteration 599, source loss =  1.587, discriminator acc = 0.156
Iteration 699, source loss =  1.418, discriminator acc = 0.855
Iteration 799, source loss =  1.138, discriminator acc = 0.817
Iteration 899, source loss =  1.320, discriminator acc = 0.964
Iteration 999, source loss =  1.040, discriminator acc = 0.650
Iteration 1099, source loss =  1.033, discriminator acc = 0.808
Iteration 1199, source loss =  0.994, discriminator acc = 0.306
Iteration 1299, source loss =  0.987, discriminator acc = 0.426
Iteration 1399, source loss =  0.818, discriminator acc = 0.965
Iteration 1499, source loss =  0.787, discriminator acc = 0.978
Iteration 1599, source loss =  0.890, discriminator

  updates = self.state_updates


Iteration 99, source loss =  2.281, discriminator acc = 0.137
Iteration 199, source loss =  2.357, discriminator acc = 0.762
Iteration 299, source loss =  1.646, discriminator acc = 0.880
Iteration 399, source loss =  1.667, discriminator acc = 0.877
Iteration 499, source loss =  1.350, discriminator acc = 0.874
Iteration 599, source loss =  1.551, discriminator acc = 0.888
Iteration 699, source loss =  1.322, discriminator acc = 0.901
Iteration 799, source loss =  1.287, discriminator acc = 0.882
Iteration 899, source loss =  1.250, discriminator acc = 0.873
Iteration 999, source loss =  1.287, discriminator acc = 0.882
Iteration 1099, source loss =  1.243, discriminator acc = 0.995
Iteration 1199, source loss =  1.294, discriminator acc = 0.950
Iteration 1299, source loss =  1.455, discriminator acc = 0.499
Iteration 1399, source loss =  1.343, discriminator acc = 0.883
Iteration 1499, source loss =  1.047, discriminator acc = 0.857
Iteration 1599, source loss =  1.096, discriminator

  updates = self.state_updates


Iteration 99, source loss =  2.441, discriminator acc = 0.847
Iteration 199, source loss =  1.537, discriminator acc = 0.865
Iteration 299, source loss =  1.552, discriminator acc = 0.855
Iteration 399, source loss =  1.503, discriminator acc = 0.871
Iteration 499, source loss =  1.373, discriminator acc = 0.883
Iteration 599, source loss =  1.590, discriminator acc = 0.873
Iteration 699, source loss =  1.297, discriminator acc = 0.885
Iteration 799, source loss =  1.325, discriminator acc = 0.875
Iteration 899, source loss =  1.265, discriminator acc = 0.878
Iteration 999, source loss =  1.179, discriminator acc = 0.887
Iteration 1099, source loss =  1.233, discriminator acc = 0.974
Iteration 1199, source loss =  1.317, discriminator acc = 0.527
Iteration 1299, source loss =  1.381, discriminator acc = 0.987
Iteration 1399, source loss =  1.063, discriminator acc = 0.959
Iteration 1499, source loss =  1.107, discriminator acc = 0.373
Iteration 1599, source loss =  0.983, discriminator

In [8]:
# NOTE(review): this cell is fully commented out. Before re-enabling it:
#   * the list comprehension below indexes with SAMPLE_ID_N instead of the
#     loop variable sample_id, so every sample_id column would be computed
#     from the same slide;
#   * numlist, gen_pred_true, adata_spatialLIBD_d, num_to_ex_d and
#     results_folder are not defined in the visible cells.
# confidence = 0.95
# if BOOTSTRAP:
#     for alpha in BOOTSTRAP_ALPHAS:
#         bootstrap_d = {}
#         for sample_id in st_sample_id_l:
            
            
#             bootstrap_d[sample_id] = {}
#             for i, num in enumerate(numlist):
                
#                 acc = [metrics.roc_auc_score(*gen_pred_true(num, adata_spatialLIBD_d[SAMPLE_ID_N], pred)[::-1]) for pred in pred_sp_boostrap_d[alpha][SAMPLE_ID_N]]
#                 test_mean = np.mean(acc)
#                 t_value = ss.t.ppf((1 + confidence) / 2.0, df=BOOTSTRAP_ROUNDS - 1)

#                 sd = np.std(acc, ddof=1)
#                 se = sd / np.sqrt(BOOTSTRAP_ROUNDS)

#                 ci_length = t_value * se

#                 ci_lower = test_mean - ci_length
#                 ci_upper = test_mean + ci_length

#                 bootstrap_d[sample_id][num_to_ex_d[num]] = (ci_lower, test_mean, ci_upper)


#         bootstrap_df = pd.DataFrame.from_dict(bootstrap_d)
#         display(bootstrap_df)
#         bootstrap_df.to_csv(os.path.join(results_folder, f'bootstrap_alpha{alpha}.csv'))

## 4. Predict cell fraction of spots and visualization

In [9]:
# NOTE(review): this cell is fully commented out. clssmodel_d,
# clssmodel_noda_d and mat_sp_test_s_d are not assigned by any active code
# in the visible cells, so it cannot be re-enabled as written.
# pred_sp_d, pred_sp_noda_d = {}, {}
# if TRAIN_USING_ALL_ST_SAMPLES:
#     for sample_id in st_sample_id_l:
#         pred_sp_d[sample_id] = clssmodel.predict(mat_sp_test_s_d[sample_id])
#         pred_sp_noda_d[sample_id] = clssmodel_noda.predict(mat_sp_test_s_d[sample_id])
# else:
#     for sample_id in st_sample_id_l:
#         pred_sp_d[sample_id] = clssmodel_d[sample_id].predict(
#             mat_sp_test_s_d[sample_id]
#         )
#         pred_sp_noda_d[sample_id] = clssmodel_noda_d[sample_id].predict(
#             mat_sp_test_s_d[sample_id]
#         )



In [10]:
# def plot_cellfraction(visnum, adata, pred_sp, ax=None):
#     """Plot predicted cell fraction for a given visnum"""
#     adata.obs["Pred_label"] = pred_sp[:, visnum]
#     # vmin = 0
#     # vmax = np.amax(pred_sp)

#     sc.pl.spatial(
#         adata,
#         img_key="hires",
#         color="Pred_label",
#         palette="Set1",
#         size=1.5,
#         legend_loc=None,
#         title=f"{sc_sub_dict[visnum]}",
#         spot_size=100,
#         show=False,
#         # vmin=vmin,
#         # vmax=vmax,
#         ax=ax,
#     )


In [11]:
# def plot_cell_layers(df):

#     layer_idx = df["spatialLIBD"].unique().astype(str)
#     samples = df["sample_id"].unique()
#     layer_idx.sort()
#     fig, ax = plt.subplots(
#         nrows=1,
#         ncols=len(samples),
#         figsize=(5 * len(samples), 5),
#         squeeze=False,
#         constrained_layout=True,
#     )

#     for idx, sample in enumerate(samples):
#         cells_of_samples = df[df["sample_id"] == sample]
#         for index in layer_idx:
#             cells_of_layer = cells_of_samples[cells_of_samples["spatialLIBD"] == index]
#             ax.flat[idx].scatter(
#                 cells_of_layer["X"], -cells_of_layer["Y"], label=index, s=17, marker="o"
#             )

#         ax.flat[idx].axis("equal")
#         ax.flat[idx].set_xticks([])
#         ax.flat[idx].set_yticks([])
#         ax.flat[idx].set_title(sample)

#     plt.legend()
#     plt.show()


In [12]:
# def plot_roc(visnum, adata, pred_sp, name, ax=None):
#     """Plot ROC for a given visnum"""

#     def layer_to_layer_number(x):
#         for char in x:
#             if char.isdigit():
#                 if int(char) in Ex_to_L_d[num_to_ex_d[visnum]]:
#                     return 1
#         return 0

#     y_pred = pred_sp[:, visnum]
#     y_true = adata.obs["spatialLIBD"].map(layer_to_layer_number).fillna(0)
#     # print(y_true)
#     # print(y_true.isna().sum())
#     RocCurveDisplay.from_predictions(y_true=y_true, y_pred=y_pred, name=name, ax=ax)


In [13]:
# # plot_cell_layers(adata_spatialLIBD_151673.obs)

# fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(5, 5), constrained_layout=True)

# sc.pl.spatial(
#     adata_spatialLIBD_d[SAMPLE_ID_N],
#     img_key=None,
#     color="spatialLIBD",
#     palette="Accent_r",
#     size=1.5,
#     title=SAMPLE_ID_N,
#     # legend_loc = 4,
#     spot_size=100,
#     show=False,
#     ax=ax,
# )

# ax.axis("equal")
# ax.set_xlabel("")
# ax.set_ylabel("")

# fig.show()


In [14]:
# fig, ax = plt.subplots(2, 5, figsize=(20, 8), constrained_layout=True)

# for i, num in enumerate(numlist):
#     plot_cellfraction(
#         num, adata_spatialLIBD_d[SAMPLE_ID_N], pred_sp_d[SAMPLE_ID_N], ax.flat[i]
#     )
#     ax.flat[i].axis("equal")
#     ax.flat[i].set_xlabel("")
#     ax.flat[i].set_ylabel("")

# fig.show()
# # plt.close()

# fig, ax = plt.subplots(
#     2, 5, figsize=(20, 8), constrained_layout=True, sharex=True, sharey=True
# )

# for i, num in enumerate(numlist):
#     plot_roc(
#         num,
#         adata_spatialLIBD_d[SAMPLE_ID_N],
#         pred_sp_d[SAMPLE_ID_N],
#         "CellDART",
#         ax.flat[i],
#     )
#     plot_roc(
#         num,
#         adata_spatialLIBD_d[SAMPLE_ID_N],
#         pred_sp_noda_d[SAMPLE_ID_N],
#         "NN_wo_da",
#         ax.flat[i],
#     )
#     ax.flat[i].plot([0, 1], [0, 1], transform=ax.flat[i].transAxes, ls="--", color="k")
#     ax.flat[i].set_aspect("equal")
#     ax.flat[i].set_xlim([0, 1])
#     ax.flat[i].set_ylim([0, 1])

#     ax.flat[i].set_title(f"{sc_sub_dict[num]}")

#     if i >= len(numlist) - 5:
#         ax.flat[i].set_xlabel("FPR")
#     else:
#         ax.flat[i].set_xlabel("")
#     if i % 5 == 0:
#         ax.flat[i].set_ylabel("TPR")
#     else:
#         ax.flat[i].set_ylabel("")

# fig.show()
# # plt.close()


- cf. Prediction of Mixture (pseudospots)


In [15]:
# if TRAIN_USING_ALL_ST_SAMPLES:
#     pred_mix = clssmodel.predict(sc_mix_test_s)
# else:
#     pred_mix = clssmodel_d[SAMPLE_ID_N].predict(sc_mix_test_s)


# cell_type_nums = sc_sub_dict.keys()
# nrows = ceil(len(cell_type_nums) / 5)

# line_kws = {"color": "tab:orange"}
# scatter_kws = {"s": 5}

# props = dict(facecolor="w", alpha=0.5)

# fig, ax = plt.subplots(
#     nrows,
#     5,
#     figsize=(20, 4 * nrows),
#     constrained_layout=True,
#     sharex=False,
#     sharey=True,
# )
# for i, num in enumerate(cell_type_nums):
#     sns.regplot(
#         x=pred_mix[:, num],
#         y=lab_mix_test[:, num],
#         line_kws=line_kws,
#         scatter_kws=scatter_kws,
#         ax=ax.flat[i],
#     ).set_title(sc_sub_dict[num])
#     ax.flat[i].set_aspect("equal")

#     ax.flat[i].set_xlabel("Predicted Proportion")
#     if i % 5 == 0:
#         ax.flat[i].set_ylabel("True Proportion")
#     else:
#         ax.flat[i].set_ylabel("")
#     ax.flat[i].set_xlim([0, 1])
#     ax.flat[i].set_ylim([0, 1])

#     textstr = f"MSE: {mean_squared_error(pred_mix[:,num], lab_mix_test[:,num]):.5f}"

#     # place a text box in upper left in axes coords
#     ax.flat[i].text(
#         0.95,
#         0.05,
#         textstr,
#         transform=ax.flat[i].transAxes,
#         verticalalignment="bottom",
#         horizontalalignment="right",
#         bbox=props,
#     )

# for i in range(len(cell_type_nums), nrows * 5):
#     ax.flat[i].axis("off")

# plt.show()


In [16]:
# print(
#     "\n".join(
#         f"{m.__name__} {m.__version__}"
#         for m in globals().values()
#         if getattr(m, "__version__", None)
#     )
# )
