# CellDART Example Code: human DLPFC (spatialLIBD)
## (adapted from the mouse-brain example: 10x Visium slides of the human dorsolateral prefrontal cortex from spatialLIBD + scRNA-seq reference data)

In [1]:
import glob
import os
import pickle
from math import ceil

import anndata as ad
import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import scanpy as sc
import seaborn as sns
import scipy.stats as ss

from sklearn.model_selection import train_test_split
from sklearn.metrics import RocCurveDisplay
from sklearn.metrics import mean_squared_error
from sklearn import metrics

from CellDART import da_cellfraction
from CellDART.utils import random_mix

from src.utils.data_loading import load_spatial, load_sc, get_selected_dir

from tqdm.autonotebook import tqdm


  from tqdm.autonotebook import tqdm


In [2]:
# Importing TensorFlow is the point at which it registers PluggableDevices.
import tensorflow as tf

# List every physical device TensorFlow can see; as the last expression of the
# cell, the result is what the notebook displays.
available_devices = tf.config.list_physical_devices()
available_devices

[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
 PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [3]:
# --- Dataset selection -------------------------------------------------------
# DATA_DIR / N_MARKERS / ALL_GENES pick the preprocessed directory via
# get_selected_dir ("<N>markers" or "all"); SCALER_NAME is forwarded to the
# loaders.
DATA_DIR = "../AGrEDA/data"
N_MARKERS = 20
ALL_GENES = False
SCALER_NAME = "celldart"

# --- Spatial (ST) data -------------------------------------------------------
# True: a single model trained on all slides; False: one model per slide.
TRAIN_USING_ALL_ST_SAMPLES = False
ST_SPLIT = False  # forwarded to load_spatial as st_split
SAMPLE_ID_N = "151673"  # slide used by the (commented-out) plotting cells

# --- Pseudo-spot settings forwarded to load_sc -------------------------------
N_SPOTS = 20000
N_MIX = 8

# --- Training hyperparameters ------------------------------------------------
INITIAL_TRAIN_EPOCHS = 10
BATCH_SIZE = 512
ALPHA = 0.6
ALPHA_LR = 5
N_ITER = 3000

# --- Optional bootstrap: BOOTSTRAP_ROUNDS seeded runs per alpha --------------
BOOTSTRAP = False
BOOTSTRAP_ROUNDS = 10
BOOTSTRAP_ALPHAS = [0.6, 1 / 0.6]

# --- Output location: model/<MODEL_NAME>/<MODEL_VERSION> ---------------------
MODEL_NAME = "CellDART"
MODEL_VERSION = "V1"

In [4]:
# Every artefact of this run is written below model/<MODEL_NAME>/<MODEL_VERSION>.
model_folder = os.path.join("model", MODEL_NAME, MODEL_VERSION)

# Create the folder (and parents) on first use and echo its path; an existing
# folder is reused without output.
model_folder_exists = os.path.isdir(model_folder)
if not model_folder_exists:
    os.makedirs(model_folder)
    print(model_folder)


In [5]:
def get_selected_dir(data_dir, n_markers=20, all_genes=False):
    """Return the preprocessed-data directory to load from.

    Note: this local definition shadows the ``get_selected_dir`` imported from
    ``src.utils.data_loading``.

    Args:
        data_dir: Root data directory.
        n_markers: Number used in the "<n>markers" subdirectory name; ignored
            when ``all_genes`` is true.
        all_genes: Select the "all" subdirectory instead of a marker subset.

    Returns:
        ``<data_dir>/preprocessed/all`` or ``<data_dir>/preprocessed/<n_markers>markers``.
    """
    subset = "all" if all_genes else f"{n_markers}markers"
    return os.path.join(data_dir, "preprocessed", subset)

## 1. Data load
### load scanpy data - 10x datasets

In [6]:
# Both loaders read from the same preprocessed directory; resolve it once.
selected_dir = get_selected_dir(DATA_DIR, N_MARKERS, ALL_GENES)

# Spatial (ST) data: per-slide matrices, a training matrix and the slide IDs.
mat_sp_d, mat_sp_train, st_sample_id_l = load_spatial(
    selected_dir,
    SCALER_NAME,
    train_using_all_st_samples=TRAIN_USING_ALL_ST_SAMPLES,
    st_split=ST_SPLIT,
)

# Single-cell reference data. Judging by the names, sc_mix_d / lab_mix_d hold
# the pseudo-spot mixtures and their labels (TODO confirm against load_sc).
sc_mix_d, lab_mix_d, sc_sub_dict, sc_sub_dict2 = load_sc(
    selected_dir,
    SCALER_NAME,
    n_mix=N_MIX,
    n_spots=N_SPOTS,
)


In [7]:
# advtrain/ receives the adversarially trained models, pretrain/ the "noda"
# counterparts saved by the training cell.
advtrain_folder = os.path.join(model_folder, "advtrain")
pretrain_folder = os.path.join(model_folder, "pretrain")

for output_dir in (advtrain_folder, pretrain_folder):
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)

In [9]:
# Train CellDART in one of three modes:
#   * TRAIN_USING_ALL_ST_SAMPLES: a single model on the combined ST training data;
#   * BOOTSTRAP: per slide, BOOTSTRAP_ROUNDS seeded runs for every alpha in
#     BOOTSTRAP_ALPHAS, collecting predictions in pred_sp_boostrap_d;
#   * otherwise: one model per slide, saved to disk.
#
# da_cellfraction.train is unpacked as (embs, embs_noda, clssmodel,
# clssmodel_noda) everywhere, matching the per-slide branch whose successful
# run is captured in the output below.

# Keyword arguments shared by every train call; values come from the config cell.
train_kwargs = {
    "alpha_lr": ALPHA_LR,
    "emb_dim": 64,
    "batch_size": BATCH_SIZE,
    "n_iterations": N_ITER,
    "initial_train": True,
    "initial_train_epochs": INITIAL_TRAIN_EPOCHS,
}

if TRAIN_USING_ALL_ST_SAMPLES:
    print("Adversarial training for all ST slides")
    embs, embs_noda, clssmodel, clssmodel_noda = da_cellfraction.train(
        sc_mix_d["train"],
        lab_mix_d["train"],
        # Combined ST training matrix returned by load_spatial
        # (assumed populated when train_using_all_st_samples=True — confirm).
        mat_sp_train,
        alpha=ALPHA,
        **train_kwargs,
    )
elif BOOTSTRAP:
    # NOTE: the "boostrap" spelling is kept because later cells use this name.
    pred_sp_boostrap_d = {}

    outer = tqdm(total=len(BOOTSTRAP_ALPHAS), desc="Alphas", position=0)
    inner1 = tqdm(total=len(st_sample_id_l), desc="Sample", position=1)
    inner2 = tqdm(total=BOOTSTRAP_ROUNDS, desc="Bootstrap #", position=2)
    try:
        for alpha in BOOTSTRAP_ALPHAS:
            inner1.refresh()  # force print final state
            inner1.reset()  # reuse bar

            pred_sp_boostrap_d[alpha] = {}

            for sample_id in st_sample_id_l:
                inner2.refresh()  # force print final state
                inner2.reset()  # reuse bar

                # Same [sample_id]["train"] indexing as the per-slide branch.
                mat_sp_sample = mat_sp_d[sample_id]["train"]
                pred_sp_boostrap_d[alpha][sample_id] = []
                for i in range(BOOTSTRAP_ROUNDS):
                    print(f"Adversarial training for ST slide {sample_id}: ")
                    embs, _, clssmodel, _ = da_cellfraction.train(
                        sc_mix_d["train"],
                        lab_mix_d["train"],
                        mat_sp_sample,
                        alpha=alpha,
                        seed=i,  # one seed per bootstrap round
                        **train_kwargs,
                    )
                    pred_sp_boostrap_d[alpha][sample_id].append(
                        clssmodel.predict(mat_sp_sample)
                    )
                    inner2.update(1)
                inner1.update(1)
            outer.update(1)
    finally:
        # Release the progress bars even if training is interrupted.
        for progress_bar in (inner2, inner1, outer):
            progress_bar.close()
else:
    # embs_d, clssmodel_d, clssmodel_noda_d = {}, {}, {}
    for sample_id in st_sample_id_l:
        print(f"Adversarial training for ST slide {sample_id}: ")
        adv_dir = os.path.join(advtrain_folder, sample_id)
        pre_dir = os.path.join(pretrain_folder, sample_id)
        # exist_ok still raises if the path exists but is not a directory.
        os.makedirs(adv_dir, exist_ok=True)
        os.makedirs(pre_dir, exist_ok=True)

        embs, embs_noda, clssmodel, clssmodel_noda = da_cellfraction.train(
            sc_mix_d["train"],
            lab_mix_d["train"],
            mat_sp_d[sample_id]["train"],
            alpha=ALPHA,
            **train_kwargs,
        )
        # embs_d[sample_id] = embs
        # clssmodel_d[sample_id] = clssmodel
        # clssmodel_noda_d[sample_id] = clssmodel_noda

        # Save classifier and embedding models: "noda" ones under pretrain/,
        # adversarially trained ones under advtrain/.
        clssmodel_noda.save(os.path.join(pre_dir, "final_model"))
        clssmodel.save(os.path.join(adv_dir, "final_model"))

        embs_noda.save(os.path.join(pre_dir, "embs"))
        embs.save(os.path.join(adv_dir, "embs"))


Adversarial training for ST slide 151507: 
Train on 20000 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
initial_train_done


  updates = self.state_updates


Iteration 99, source loss =  4.918, discriminator acc = 0.174
Iteration 199, source loss =  2.263, discriminator acc = 0.463
Iteration 299, source loss =  1.516, discriminator acc = 0.173
Iteration 399, source loss =  1.667, discriminator acc = 0.174
Iteration 499, source loss =  1.583, discriminator acc = 0.174
Iteration 599, source loss =  1.393, discriminator acc = 0.174
Iteration 699, source loss =  1.592, discriminator acc = 0.234
Iteration 799, source loss =  1.433, discriminator acc = 0.828
Iteration 899, source loss =  1.264, discriminator acc = 0.244
Iteration 999, source loss =  1.579, discriminator acc = 0.936
Iteration 1099, source loss =  1.166, discriminator acc = 0.404
Iteration 1199, source loss =  1.632, discriminator acc = 0.852
Iteration 1299, source loss =  1.121, discriminator acc = 0.254
Iteration 1399, source loss =  1.172, discriminator acc = 0.790
Iteration 1499, source loss =  1.006, discriminator acc = 0.813
Iteration 1599, source loss =  0.947, discriminator

  updates = self.state_updates


Iteration 99, source loss =  3.589, discriminator acc = 0.180
Iteration 199, source loss =  2.009, discriminator acc = 0.232
Iteration 299, source loss =  1.777, discriminator acc = 0.887
Iteration 399, source loss =  1.220, discriminator acc = 0.948
Iteration 499, source loss =  1.608, discriminator acc = 0.732
Iteration 599, source loss =  1.313, discriminator acc = 0.855
Iteration 699, source loss =  1.117, discriminator acc = 0.964
Iteration 799, source loss =  1.088, discriminator acc = 0.997
Iteration 899, source loss =  0.959, discriminator acc = 0.714
Iteration 999, source loss =  0.956, discriminator acc = 0.117
Iteration 1099, source loss =  0.788, discriminator acc = 0.931
Iteration 1199, source loss =  0.786, discriminator acc = 0.984
Iteration 1299, source loss =  0.786, discriminator acc = 0.055
Iteration 1399, source loss =  0.760, discriminator acc = 0.376
Iteration 1499, source loss =  0.801, discriminator acc = 0.106
Iteration 1599, source loss =  0.860, discriminator

  updates = self.state_updates


Iteration 99, source loss =  4.164, discriminator acc = 0.221
Iteration 199, source loss =  1.654, discriminator acc = 0.814
Iteration 299, source loss =  1.499, discriminator acc = 0.810
Iteration 399, source loss =  1.306, discriminator acc = 0.817
Iteration 499, source loss =  1.822, discriminator acc = 0.823
Iteration 599, source loss =  1.445, discriminator acc = 0.835
Iteration 699, source loss =  1.271, discriminator acc = 0.844
Iteration 799, source loss =  1.172, discriminator acc = 0.835
Iteration 899, source loss =  1.116, discriminator acc = 0.827
Iteration 999, source loss =  1.169, discriminator acc = 0.855
Iteration 1099, source loss =  1.285, discriminator acc = 0.863
Iteration 1199, source loss =  1.247, discriminator acc = 0.832
Iteration 1299, source loss =  1.199, discriminator acc = 0.849
Iteration 1399, source loss =  1.228, discriminator acc = 0.932
Iteration 1499, source loss =  1.229, discriminator acc = 0.932
Iteration 1599, source loss =  1.208, discriminator

  updates = self.state_updates


Iteration 99, source loss =  2.185, discriminator acc = 0.188
Iteration 199, source loss =  2.618, discriminator acc = 0.188
Iteration 299, source loss =  1.700, discriminator acc = 0.894
Iteration 399, source loss =  1.621, discriminator acc = 0.855
Iteration 499, source loss =  1.318, discriminator acc = 0.403
Iteration 599, source loss =  1.422, discriminator acc = 0.839
Iteration 699, source loss =  1.420, discriminator acc = 0.933
Iteration 799, source loss =  1.202, discriminator acc = 0.855
Iteration 899, source loss =  1.177, discriminator acc = 0.887
Iteration 999, source loss =  1.350, discriminator acc = 0.574
Iteration 1099, source loss =  0.981, discriminator acc = 0.837
Iteration 1199, source loss =  0.905, discriminator acc = 0.507
Iteration 1299, source loss =  0.934, discriminator acc = 0.558
Iteration 1399, source loss =  0.831, discriminator acc = 0.874
Iteration 1499, source loss =  0.846, discriminator acc = 0.867
Iteration 1599, source loss =  0.733, discriminator

  updates = self.state_updates


Iteration 99, source loss =  3.713, discriminator acc = 0.060
Iteration 199, source loss =  2.398, discriminator acc = 0.161
Iteration 299, source loss =  1.368, discriminator acc = 0.998
Iteration 399, source loss =  1.573, discriminator acc = 0.414
Iteration 499, source loss =  1.352, discriminator acc = 0.966
Iteration 599, source loss =  1.101, discriminator acc = 0.996
Iteration 699, source loss =  1.117, discriminator acc = 0.219
Iteration 799, source loss =  0.916, discriminator acc = 0.374
Iteration 899, source loss =  0.887, discriminator acc = 0.807
Iteration 999, source loss =  0.917, discriminator acc = 0.915
Iteration 1099, source loss =  0.819, discriminator acc = 0.639
Iteration 1199, source loss =  0.894, discriminator acc = 0.392
Iteration 1299, source loss =  0.767, discriminator acc = 0.518
Iteration 1399, source loss =  0.697, discriminator acc = 0.917
Iteration 1499, source loss =  0.725, discriminator acc = 0.945
Iteration 1599, source loss =  0.828, discriminator

  updates = self.state_updates


Iteration 99, source loss =  3.527, discriminator acc = 0.877
Iteration 199, source loss =  1.297, discriminator acc = 0.857
Iteration 299, source loss =  1.680, discriminator acc = 0.851
Iteration 399, source loss =  1.277, discriminator acc = 0.851
Iteration 499, source loss =  1.387, discriminator acc = 0.853
Iteration 599, source loss =  1.212, discriminator acc = 0.852
Iteration 699, source loss =  1.122, discriminator acc = 0.852
Iteration 799, source loss =  1.325, discriminator acc = 0.998
Iteration 899, source loss =  1.202, discriminator acc = 0.830
Iteration 999, source loss =  1.052, discriminator acc = 0.931
Iteration 1099, source loss =  1.011, discriminator acc = 0.990
Iteration 1199, source loss =  0.950, discriminator acc = 0.743
Iteration 1299, source loss =  0.930, discriminator acc = 0.158
Iteration 1399, source loss =  0.806, discriminator acc = 0.718
Iteration 1499, source loss =  0.784, discriminator acc = 0.979
Iteration 1599, source loss =  0.791, discriminator

  updates = self.state_updates


Iteration 99, source loss =  3.681, discriminator acc = 0.170
Iteration 199, source loss =  2.170, discriminator acc = 0.170
Iteration 299, source loss =  1.765, discriminator acc = 0.171
Iteration 399, source loss =  1.510, discriminator acc = 0.989
Iteration 499, source loss =  1.448, discriminator acc = 0.751
Iteration 599, source loss =  1.100, discriminator acc = 0.324
Iteration 699, source loss =  1.154, discriminator acc = 0.580
Iteration 799, source loss =  1.247, discriminator acc = 0.152
Iteration 899, source loss =  1.061, discriminator acc = 0.997
Iteration 999, source loss =  1.120, discriminator acc = 0.099
Iteration 1099, source loss =  1.012, discriminator acc = 0.945
Iteration 1199, source loss =  0.988, discriminator acc = 0.261
Iteration 1299, source loss =  1.008, discriminator acc = 0.515
Iteration 1399, source loss =  1.040, discriminator acc = 0.248
Iteration 1499, source loss =  0.975, discriminator acc = 0.759
Iteration 1599, source loss =  1.021, discriminator

  updates = self.state_updates


Iteration 99, source loss =  3.466, discriminator acc = 0.006
Iteration 199, source loss =  2.325, discriminator acc = 0.838
Iteration 299, source loss =  1.686, discriminator acc = 0.836
Iteration 399, source loss =  1.448, discriminator acc = 0.837
Iteration 499, source loss =  1.371, discriminator acc = 0.992
Iteration 599, source loss =  1.693, discriminator acc = 0.690
Iteration 699, source loss =  1.463, discriminator acc = 0.999
Iteration 799, source loss =  1.195, discriminator acc = 0.828
Iteration 899, source loss =  1.175, discriminator acc = 0.942
Iteration 999, source loss =  1.111, discriminator acc = 0.759
Iteration 1099, source loss =  1.195, discriminator acc = 0.830
Iteration 1199, source loss =  1.023, discriminator acc = 0.667
Iteration 1299, source loss =  0.926, discriminator acc = 0.718
Iteration 1399, source loss =  0.859, discriminator acc = 0.701
Iteration 1499, source loss =  0.811, discriminator acc = 0.727
Iteration 1599, source loss =  0.793, discriminator

  updates = self.state_updates


Iteration 99, source loss =  2.913, discriminator acc = 0.879
Iteration 199, source loss =  1.648, discriminator acc = 0.865
Iteration 299, source loss =  1.569, discriminator acc = 0.867
Iteration 399, source loss =  1.516, discriminator acc = 0.872
Iteration 499, source loss =  1.601, discriminator acc = 0.870
Iteration 599, source loss =  1.291, discriminator acc = 0.872
Iteration 699, source loss =  1.313, discriminator acc = 0.880
Iteration 799, source loss =  1.225, discriminator acc = 0.893
Iteration 899, source loss =  1.385, discriminator acc = 0.882
Iteration 999, source loss =  1.268, discriminator acc = 0.862
Iteration 1099, source loss =  1.505, discriminator acc = 0.859
Iteration 1199, source loss =  1.432, discriminator acc = 0.980
Iteration 1299, source loss =  1.388, discriminator acc = 0.897
Iteration 1399, source loss =  1.250, discriminator acc = 0.804
Iteration 1499, source loss =  1.306, discriminator acc = 0.830
Iteration 1599, source loss =  1.043, discriminator

  updates = self.state_updates


Iteration 99, source loss =  1.965, discriminator acc = 0.155
Iteration 199, source loss =  2.012, discriminator acc = 0.865
Iteration 299, source loss =  1.902, discriminator acc = 0.865
Iteration 399, source loss =  1.483, discriminator acc = 0.852
Iteration 499, source loss =  1.588, discriminator acc = 0.865
Iteration 599, source loss =  1.532, discriminator acc = 0.873
Iteration 699, source loss =  1.283, discriminator acc = 0.874
Iteration 799, source loss =  1.294, discriminator acc = 0.878
Iteration 899, source loss =  1.209, discriminator acc = 0.846
Iteration 999, source loss =  1.249, discriminator acc = 0.866
Iteration 1099, source loss =  1.163, discriminator acc = 0.878
Iteration 1199, source loss =  1.253, discriminator acc = 0.954
Iteration 1299, source loss =  1.286, discriminator acc = 0.863
Iteration 1399, source loss =  1.358, discriminator acc = 0.851
Iteration 1499, source loss =  1.245, discriminator acc = 0.868
Iteration 1599, source loss =  1.233, discriminator

  updates = self.state_updates


Iteration 99, source loss =  2.303, discriminator acc = 0.865
Iteration 199, source loss =  2.572, discriminator acc = 0.863
Iteration 299, source loss =  1.606, discriminator acc = 0.860
Iteration 399, source loss =  1.874, discriminator acc = 0.871
Iteration 499, source loss =  1.294, discriminator acc = 0.882
Iteration 599, source loss =  1.523, discriminator acc = 0.873
Iteration 699, source loss =  1.451, discriminator acc = 0.866
Iteration 799, source loss =  1.544, discriminator acc = 0.884
Iteration 899, source loss =  1.544, discriminator acc = 0.884
Iteration 999, source loss =  1.328, discriminator acc = 0.881
Iteration 1099, source loss =  1.350, discriminator acc = 0.887
Iteration 1199, source loss =  1.423, discriminator acc = 0.880
Iteration 1299, source loss =  1.338, discriminator acc = 0.884
Iteration 1399, source loss =  1.201, discriminator acc = 0.888
Iteration 1499, source loss =  1.299, discriminator acc = 0.877
Iteration 1599, source loss =  1.260, discriminator

  updates = self.state_updates


Iteration 99, source loss =  4.264, discriminator acc = 0.072
Iteration 199, source loss =  1.424, discriminator acc = 0.100
Iteration 299, source loss =  1.700, discriminator acc = 0.160
Iteration 399, source loss =  1.227, discriminator acc = 0.975
Iteration 499, source loss =  1.461, discriminator acc = 0.937
Iteration 599, source loss =  1.639, discriminator acc = 0.136
Iteration 699, source loss =  1.269, discriminator acc = 1.000
Iteration 799, source loss =  1.082, discriminator acc = 0.627
Iteration 899, source loss =  1.180, discriminator acc = 0.943
Iteration 999, source loss =  1.098, discriminator acc = 0.148
Iteration 1099, source loss =  1.017, discriminator acc = 0.953
Iteration 1199, source loss =  1.008, discriminator acc = 0.643
Iteration 1299, source loss =  0.909, discriminator acc = 0.896
Iteration 1399, source loss =  1.012, discriminator acc = 0.791
Iteration 1499, source loss =  0.958, discriminator acc = 0.362
Iteration 1599, source loss =  0.907, discriminator

In [None]:
# confidence = 0.95
# if BOOTSTRAP:
#     for alpha in BOOTSTRAP_ALPHAS:
#         bootstrap_d = {}
#         for sample_id in st_sample_id_l:
            
            
#             bootstrap_d[sample_id] = {}
#             for i, num in enumerate(numlist):
                
#                 acc = [metrics.roc_auc_score(*gen_pred_true(num, adata_spatialLIBD_d[SAMPLE_ID_N], pred)[::-1]) for pred in pred_sp_boostrap_d[alpha][SAMPLE_ID_N]]
#                 test_mean = np.mean(acc)
#                 t_value = ss.t.ppf((1 + confidence) / 2.0, df=BOOTSTRAP_ROUNDS - 1)

#                 sd = np.std(acc, ddof=1)
#                 se = sd / np.sqrt(BOOTSTRAP_ROUNDS)

#                 ci_length = t_value * se

#                 ci_lower = test_mean - ci_length
#                 ci_upper = test_mean + ci_length

#                 bootstrap_d[sample_id][num_to_ex_d[num]] = (ci_lower, test_mean, ci_upper)


#         bootstrap_df = pd.DataFrame.from_dict(bootstrap_d)
#         display(bootstrap_df)
#         bootstrap_df.to_csv(os.path.join(results_folder, f'bootstrap_alpha{alpha}.csv'))

## 4. Predict cell fraction of spots and visualization

In [None]:
# pred_sp_d, pred_sp_noda_d = {}, {}
# if TRAIN_USING_ALL_ST_SAMPLES:
#     for sample_id in st_sample_id_l:
#         pred_sp_d[sample_id] = clssmodel.predict(mat_sp_test_s_d[sample_id])
#         pred_sp_noda_d[sample_id] = clssmodel_noda.predict(mat_sp_test_s_d[sample_id])
# else:
#     for sample_id in st_sample_id_l:
#         pred_sp_d[sample_id] = clssmodel_d[sample_id].predict(
#             mat_sp_test_s_d[sample_id]
#         )
#         pred_sp_noda_d[sample_id] = clssmodel_noda_d[sample_id].predict(
#             mat_sp_test_s_d[sample_id]
#         )



In [None]:
# def plot_cellfraction(visnum, adata, pred_sp, ax=None):
#     """Plot predicted cell fraction for a given visnum"""
#     adata.obs["Pred_label"] = pred_sp[:, visnum]
#     # vmin = 0
#     # vmax = np.amax(pred_sp)

#     sc.pl.spatial(
#         adata,
#         img_key="hires",
#         color="Pred_label",
#         palette="Set1",
#         size=1.5,
#         legend_loc=None,
#         title=f"{sc_sub_dict[visnum]}",
#         spot_size=100,
#         show=False,
#         # vmin=vmin,
#         # vmax=vmax,
#         ax=ax,
#     )


In [None]:
# def plot_cell_layers(df):

#     layer_idx = df["spatialLIBD"].unique().astype(str)
#     samples = df["sample_id"].unique()
#     layer_idx.sort()
#     fig, ax = plt.subplots(
#         nrows=1,
#         ncols=len(samples),
#         figsize=(5 * len(samples), 5),
#         squeeze=False,
#         constrained_layout=True,
#     )

#     for idx, sample in enumerate(samples):
#         cells_of_samples = df[df["sample_id"] == sample]
#         for index in layer_idx:
#             cells_of_layer = cells_of_samples[cells_of_samples["spatialLIBD"] == index]
#             ax.flat[idx].scatter(
#                 cells_of_layer["X"], -cells_of_layer["Y"], label=index, s=17, marker="o"
#             )

#         ax.flat[idx].axis("equal")
#         ax.flat[idx].set_xticks([])
#         ax.flat[idx].set_yticks([])
#         ax.flat[idx].set_title(sample)

#     plt.legend()
#     plt.show()


In [None]:
# def plot_roc(visnum, adata, pred_sp, name, ax=None):
#     """Plot ROC for a given visnum"""

#     def layer_to_layer_number(x):
#         for char in x:
#             if char.isdigit():
#                 if int(char) in Ex_to_L_d[num_to_ex_d[visnum]]:
#                     return 1
#         return 0

#     y_pred = pred_sp[:, visnum]
#     y_true = adata.obs["spatialLIBD"].map(layer_to_layer_number).fillna(0)
#     # print(y_true)
#     # print(y_true.isna().sum())
#     RocCurveDisplay.from_predictions(y_true=y_true, y_pred=y_pred, name=name, ax=ax)


In [None]:
# # plot_cell_layers(adata_spatialLIBD_151673.obs)

# fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(5, 5), constrained_layout=True)

# sc.pl.spatial(
#     adata_spatialLIBD_d[SAMPLE_ID_N],
#     img_key=None,
#     color="spatialLIBD",
#     palette="Accent_r",
#     size=1.5,
#     title=SAMPLE_ID_N,
#     # legend_loc = 4,
#     spot_size=100,
#     show=False,
#     ax=ax,
# )

# ax.axis("equal")
# ax.set_xlabel("")
# ax.set_ylabel("")

# fig.show()


In [None]:
# fig, ax = plt.subplots(2, 5, figsize=(20, 8), constrained_layout=True)

# for i, num in enumerate(numlist):
#     plot_cellfraction(
#         num, adata_spatialLIBD_d[SAMPLE_ID_N], pred_sp_d[SAMPLE_ID_N], ax.flat[i]
#     )
#     ax.flat[i].axis("equal")
#     ax.flat[i].set_xlabel("")
#     ax.flat[i].set_ylabel("")

# fig.show()
# # plt.close()

# fig, ax = plt.subplots(
#     2, 5, figsize=(20, 8), constrained_layout=True, sharex=True, sharey=True
# )

# for i, num in enumerate(numlist):
#     plot_roc(
#         num,
#         adata_spatialLIBD_d[SAMPLE_ID_N],
#         pred_sp_d[SAMPLE_ID_N],
#         "CellDART",
#         ax.flat[i],
#     )
#     plot_roc(
#         num,
#         adata_spatialLIBD_d[SAMPLE_ID_N],
#         pred_sp_noda_d[SAMPLE_ID_N],
#         "NN_wo_da",
#         ax.flat[i],
#     )
#     ax.flat[i].plot([0, 1], [0, 1], transform=ax.flat[i].transAxes, ls="--", color="k")
#     ax.flat[i].set_aspect("equal")
#     ax.flat[i].set_xlim([0, 1])
#     ax.flat[i].set_ylim([0, 1])

#     ax.flat[i].set_title(f"{sc_sub_dict[num]}")

#     if i >= len(numlist) - 5:
#         ax.flat[i].set_xlabel("FPR")
#     else:
#         ax.flat[i].set_xlabel("")
#     if i % 5 == 0:
#         ax.flat[i].set_ylabel("TPR")
#     else:
#         ax.flat[i].set_ylabel("")

# fig.show()
# # plt.close()


- cf. Prediction of Mixture (pseudospots)


In [None]:
# if TRAIN_USING_ALL_ST_SAMPLES:
#     pred_mix = clssmodel.predict(sc_mix_test_s)
# else:
#     pred_mix = clssmodel_d[SAMPLE_ID_N].predict(sc_mix_test_s)


# cell_type_nums = sc_sub_dict.keys()
# nrows = ceil(len(cell_type_nums) / 5)

# line_kws = {"color": "tab:orange"}
# scatter_kws = {"s": 5}

# props = dict(facecolor="w", alpha=0.5)

# fig, ax = plt.subplots(
#     nrows,
#     5,
#     figsize=(20, 4 * nrows),
#     constrained_layout=True,
#     sharex=False,
#     sharey=True,
# )
# for i, num in enumerate(cell_type_nums):
#     sns.regplot(
#         x=pred_mix[:, num],
#         y=lab_mix_test[:, num],
#         line_kws=line_kws,
#         scatter_kws=scatter_kws,
#         ax=ax.flat[i],
#     ).set_title(sc_sub_dict[num])
#     ax.flat[i].set_aspect("equal")

#     ax.flat[i].set_xlabel("Predicted Proportion")
#     if i % 5 == 0:
#         ax.flat[i].set_ylabel("True Proportion")
#     else:
#         ax.flat[i].set_ylabel("")
#     ax.flat[i].set_xlim([0, 1])
#     ax.flat[i].set_ylim([0, 1])

#     textstr = f"MSE: {mean_squared_error(pred_mix[:,num], lab_mix_test[:,num]):.5f}"

#     # place a text box in upper left in axes coords
#     ax.flat[i].text(
#         0.95,
#         0.05,
#         textstr,
#         transform=ax.flat[i].transAxes,
#         verticalalignment="bottom",
#         horizontalalignment="right",
#         bbox=props,
#     )

# for i in range(len(cell_type_nums), nrows * 5):
#     ax.flat[i].axis("off")

# plt.show()


In [None]:
# print(
#     "\n".join(
#         f"{m.__name__} {m.__version__}"
#         for m in globals().values()
#         if getattr(m, "__version__", None)
#     )
# )
