# CellDART Example Code: PDAC (adapted from the mouse brain example)
## (spatial transcriptomics GSE111672 + scRNA-seq CA001063 of pancreatic ductal adenocarcinoma, as set in the loaded config)

In [1]:
import os

import numpy as np
import tensorflow as tf  # TensorFlow registers PluggableDevices here.
from tqdm.autonotebook import tqdm
import yaml

from CellDART import da_cellfraction
from src.da_utils import data_loading




  from tqdm.autonotebook import tqdm


In [2]:
# Show every device TensorFlow can see (device_type=None means "all types"),
# e.g. to confirm a GPU is visible before training.
tf.config.list_physical_devices(device_type=None)


[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
 PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [3]:
# Paired seeds: the i-th pseudospot seed is used together with the i-th model
# seed for one training run and one evaluation run.
PS_SEEDS = (3679, 343, 25, 234, 98098)
MODEL_SEEDS = (2353, 24385, 284, 86322, 98237)
# zip() silently drops unmatched seeds, so fail loudly if the tuples drift apart.
assert len(PS_SEEDS) == len(MODEL_SEEDS), "PS_SEEDS and MODEL_SEEDS must have the same length"

# Root directory for trained models (used by both training and evaluation).
MODEL_DIR = "model_FINAL"
# NOTE(review): not referenced in the visible cells; evaluation passes each
# model seed via --seed_override instead.
SEED_OVERRIDE = None

# The config is read from CONFIGS_DIR/MODEL_NAME/CONFIG_FNAME.
CONFIGS_DIR = "configs"
CONFIG_FNAME = "celldart-final-pdac-ht.yml"

MODEL_NAME = "CellDART_original"


In [4]:
# Load the training config and back-fill defaults for keys that older config
# files may lack; the config is rewritten on disk only if something was added.
config_path = os.path.join(CONFIGS_DIR, MODEL_NAME, CONFIG_FNAME)
with open(config_path, "r") as f:
    config = yaml.safe_load(f)

lib_params = config["lib_params"]
data_params = config["data_params"]
model_params = config["model_params"]
train_params = config["train_params"]

# Defaults applied only when the key is absent from train_params.
TRAIN_PARAM_DEFAULTS = {"pretraining": True, "lr": 0.001}

rewrite_config = False
for param_name, param_default in TRAIN_PARAM_DEFAULTS.items():
    if param_name not in train_params:
        train_params[param_name] = param_default
        rewrite_config = True

# Persist back-filled defaults so the file on disk matches what is used below.
if rewrite_config:
    with open(config_path, "w") as f:
        yaml.safe_dump(config, f)

tqdm.write(yaml.safe_dump(config))


data_params:
  all_genes: false
  data_dir: ../AGrEDA/data
  dset: pdac
  n_markers: 80
  n_mix: 70
  n_spots: 100000
  one_model: true
  samp_split: false
  sc_id: CA001063
  scaler_name: minmax
  st_id: GSE111672
  st_split: false
lib_params:
  manual_seed: 2123428735
model_params:
  celldart_kwargs:
    bn_momentum: 0.01
    emb_dim: 32
  model_version: gen_pdac-18486
train_params:
  alpha: 0.6
  alpha_lr: 5
  batch_size: 128
  initial_train_epochs: 10
  lr: 0.01
  n_iter: 15000
  pretraining: true
  reverse_val: false



## 1. Data loading and training
### Load spatial and single-cell data, then train one CellDART model per seed pair

In [5]:
def _stack_target_splits(mat_sp_d):
    """Concatenate spatial matrices per split across all samples.

    Handles the two layouts checked for here: ``{split: {sample: array}}`` when
    ``"train"`` is a top-level key, otherwise ``{sample: {split: array}}``.

    Returns:
        dict mapping each split name to the concatenation (along axis 0) of that
        split's matrices over all samples.
    """
    if "train" in mat_sp_d:
        # keys of dict are splits
        return {
            split: np.concatenate(list(by_sample.values()))
            for split, by_sample in mat_sp_d.items()
        }
    # keys of subdicts are splits; np.concatenate needs a sequence, not a generator
    first_sample = next(iter(mat_sp_d.values()))
    return {
        split: np.concatenate([by_split[split] for by_split in mat_sp_d.values()])
        for split in first_sample
    }


def _slide_ids(mat_sp_d):
    """Return the sample (slide) keys of ``mat_sp_d`` for either layout handled above."""
    if "train" in mat_sp_d:
        return list(mat_sp_d["train"])
    return list(mat_sp_d)


def train(ps_seed, model_seed):
    """Train CellDART for one (pseudospot seed, model seed) pair and save the results.

    Loads spatial and single-cell data according to ``data_params``, then runs
    ``da_cellfraction.train``. It saves the pretrained ("noda") and adversarially
    trained classifier/embedding models, plus a copy of ``config``, under
    ``MODEL_DIR``.

    Args:
        ps_seed: forwarded to ``data_loading.load_sc`` as ``seed_int``.
        model_seed: forwarded to ``da_cellfraction.train`` as ``seed`` and to
            ``data_loading.get_model_rel_path`` as ``lib_seed_path``.
    """
    print(f"PS seed: {ps_seed}, model seed: {model_seed}")

    model_folder = os.path.join(
        MODEL_DIR,
        data_loading.get_model_rel_path(
            MODEL_NAME,
            model_params["model_version"],
            lib_seed_path=str(model_seed),
            **data_params,
        ),
    )
    if not os.path.isdir(model_folder):
        os.makedirs(model_folder)
        print(model_folder)

    selected_dir = data_loading.get_selected_dir(
        data_loading.get_dset_dir(
            data_params["data_dir"],
            dset=data_params.get("dset", "dlpfc"),
        ),
        **data_params,
    )
    # Load spatial data
    mat_sp_d, mat_sp_meta_d, st_sample_id_l = data_loading.load_spatial(
        selected_dir,
        **data_params,
    )

    # Load sc data; seed_int=ps_seed selects the pseudospot seed
    sc_mix_d, lab_mix_d, sc_sub_dict, sc_sub_dict2 = data_loading.load_sc(
        selected_dir,
        **data_params,
        seed_int=ps_seed,
    )

    target_d = _stack_target_splits(mat_sp_d)

    pretrain_folder = os.path.join(model_folder, "pretrain")
    advtrain_folder = os.path.join(model_folder, "advtrain")
    save_folder = os.path.join(
        advtrain_folder, "samp_split" if data_params.get("samp_split") else "one_model"
    )
    os.makedirs(pretrain_folder, exist_ok=True)
    os.makedirs(save_folder, exist_ok=True)  # also creates advtrain_folder

    tqdm.write(f"Adversarial training for slides {_slide_ids(mat_sp_d)}: ")

    # NOTE(review): train_params["lr"] and model_params["celldart_kwargs"]["bn_momentum"]
    # exist in the config but are not forwarded here; confirm whether
    # da_cellfraction.train should receive them.
    embs, embs_noda, clssmodel, clssmodel_noda = da_cellfraction.train(
        sc_mix_d["train"],
        lab_mix_d["train"],
        target_d["train"],
        alpha=train_params.get("alpha", 0.6),
        alpha_lr=train_params.get("alpha_lr", 5),
        emb_dim=model_params["celldart_kwargs"].get("emb_dim", 64),
        batch_size=train_params.get("batch_size", 512),
        n_iterations=train_params.get("n_iter", 3000),
        initial_train=train_params.get("pretraining", True),
        initial_train_epochs=train_params.get("initial_train_epochs", 10),
        # pretraining batch size is never smaller than 512
        batch_size_initial_train=max(train_params.get("batch_size", 512), 512),
        seed=model_seed,
    )

    # Save models: "noda" (pretrained only) under pretrain/, adapted under advtrain/
    advtrain_final = os.path.join(save_folder, "final_model")
    pretrain_final = os.path.join(pretrain_folder, "final_model")
    os.makedirs(advtrain_final, exist_ok=True)
    os.makedirs(pretrain_final, exist_ok=True)

    clssmodel_noda.save(os.path.join(pretrain_final, "model"))
    clssmodel.save(os.path.join(advtrain_final, "model"))

    embs_noda.save(os.path.join(pretrain_final, "embs"))
    embs.save(os.path.join(advtrain_final, "embs"))

    with open(os.path.join(model_folder, "config.yml"), "w") as f:
        yaml.safe_dump(config, f)


# One training run per paired (pseudospot seed, model seed).
for ps_seed, model_seed in zip(PS_SEEDS, MODEL_SEEDS):
    train(ps_seed=ps_seed, model_seed=model_seed)

PS seed: 3679, model seed: 2353
Adversarial training for slides train: 
Train on 100000 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
initial_train_done


  updates = self.state_updates


0.021543493438661097
0.021543493438661097
Iteration 99, source loss =  0.755, discriminator acc = 1.000
Iteration 199, source loss =  0.307, discriminator acc = 0.682
Iteration 299, source loss =  1.147, discriminator acc = 0.827
Iteration 399, source loss =  0.222, discriminator acc = 0.832
Iteration 499, source loss =  0.327, discriminator acc = 1.000
Iteration 599, source loss =  0.244, discriminator acc = 1.000
Iteration 699, source loss =  0.121, discriminator acc = 1.000
Iteration 799, source loss =  0.320, discriminator acc = 0.981
Iteration 899, source loss =  0.444, discriminator acc = 0.981
Iteration 999, source loss =  0.223, discriminator acc = 0.984
Iteration 1099, source loss =  0.134, discriminator acc = 0.911
Iteration 1199, source loss =  0.354, discriminator acc = 0.995
Iteration 1299, source loss =  0.193, discriminator acc = 0.139
Iteration 1399, source loss =  0.134, discriminator acc = 0.014
Iteration 1499, source loss =  0.090, discriminator acc = 0.970
Iteration

  updates = self.state_updates


0.02375203073322773
0.02375203073322773
Iteration 99, source loss =  0.889, discriminator acc = 0.006
Iteration 199, source loss =  0.364, discriminator acc = 0.006
Iteration 299, source loss =  0.678, discriminator acc = 0.007
Iteration 399, source loss =  0.564, discriminator acc = 0.006
Iteration 499, source loss =  0.309, discriminator acc = 0.006
Iteration 599, source loss =  0.200, discriminator acc = 0.014
Iteration 699, source loss =  0.240, discriminator acc = 0.793
Iteration 799, source loss =  0.322, discriminator acc = 0.386
Iteration 899, source loss =  0.122, discriminator acc = 0.006
Iteration 999, source loss =  0.151, discriminator acc = 0.006
Iteration 1099, source loss =  0.387, discriminator acc = 0.006
Iteration 1199, source loss =  0.127, discriminator acc = 0.913
Iteration 1299, source loss =  0.161, discriminator acc = 0.006
Iteration 1399, source loss =  0.203, discriminator acc = 0.994
Iteration 1499, source loss =  0.177, discriminator acc = 0.983
Iteration 1

  updates = self.state_updates


0.021420157658457756
0.021420157658457756
Iteration 99, source loss =  0.489, discriminator acc = 0.370
Iteration 199, source loss =  0.652, discriminator acc = 1.000
Iteration 299, source loss =  0.208, discriminator acc = 1.000
Iteration 399, source loss =  0.229, discriminator acc = 1.000
Iteration 499, source loss =  0.224, discriminator acc = 1.000
Iteration 599, source loss =  0.116, discriminator acc = 1.000
Iteration 699, source loss =  0.367, discriminator acc = 1.000
Iteration 799, source loss =  0.123, discriminator acc = 1.000
Iteration 899, source loss =  0.132, discriminator acc = 1.000
Iteration 999, source loss =  0.102, discriminator acc = 0.088
Iteration 1099, source loss =  0.564, discriminator acc = 0.990
Iteration 1199, source loss =  0.137, discriminator acc = 1.000
Iteration 1299, source loss =  0.200, discriminator acc = 0.130
Iteration 1399, source loss =  0.199, discriminator acc = 0.999
Iteration 1499, source loss =  0.187, discriminator acc = 0.999
Iteration

  updates = self.state_updates


0.01992515220373869
0.01992515220373869
Iteration 99, source loss =  1.127, discriminator acc = 0.991
Iteration 199, source loss =  0.994, discriminator acc = 0.880
Iteration 299, source loss =  0.446, discriminator acc = 0.998
Iteration 399, source loss =  0.406, discriminator acc = 1.000
Iteration 499, source loss =  0.393, discriminator acc = 0.015
Iteration 599, source loss =  0.304, discriminator acc = 0.983
Iteration 699, source loss =  0.224, discriminator acc = 1.000
Iteration 799, source loss =  0.307, discriminator acc = 0.930
Iteration 899, source loss =  0.126, discriminator acc = 0.950
Iteration 999, source loss =  0.237, discriminator acc = 0.999
Iteration 1099, source loss =  0.232, discriminator acc = 0.996
Iteration 1199, source loss =  0.101, discriminator acc = 0.979
Iteration 1299, source loss =  0.176, discriminator acc = 0.996
Iteration 1399, source loss =  0.115, discriminator acc = 0.003
Iteration 1499, source loss =  0.076, discriminator acc = 0.311
Iteration 1

  updates = self.state_updates


0.025033984021544456
0.025033984021544456
Iteration 99, source loss =  1.192, discriminator acc = 0.006
Iteration 199, source loss =  0.536, discriminator acc = 0.969
Iteration 299, source loss =  0.326, discriminator acc = 0.007
Iteration 399, source loss =  0.326, discriminator acc = 0.227
Iteration 499, source loss =  0.452, discriminator acc = 0.993
Iteration 599, source loss =  0.219, discriminator acc = 0.131
Iteration 699, source loss =  0.363, discriminator acc = 0.986
Iteration 799, source loss =  0.273, discriminator acc = 0.999
Iteration 899, source loss =  0.317, discriminator acc = 0.972
Iteration 999, source loss =  0.351, discriminator acc = 0.740
Iteration 1099, source loss =  0.110, discriminator acc = 0.999
Iteration 1199, source loss =  0.142, discriminator acc = 0.502
Iteration 1299, source loss =  0.177, discriminator acc = 0.596
Iteration 1399, source loss =  0.088, discriminator acc = 0.998
Iteration 1499, source loss =  0.091, discriminator acc = 0.918
Iteration

## 2. Evaluation


In [7]:
import argparse
import datetime
import logging

from src.da_models.model_utils.utils import get_metric_ctp
from src.da_utils.evaluator import Evaluator


# Cell-type-proportion metric handed to the Evaluator ("cos": cosine distance).
metric_ctp = get_metric_ctp("cos")


def main(args):
    """Evaluate one trained model as described by the parsed CLI-style ``args``.

    Runs spot, embedding and single-cell evaluation in that order, then writes
    the results via ``Evaluator.produce_results``.
    """
    evaluator = Evaluator(vars(args), metric_ctp)
    # Resolve each step on the evaluator just before calling it, in order.
    for step_name in ("eval_spots", "evaluate_embeddings", "eval_sc"):
        getattr(evaluator, step_name)()

    evaluator.produce_results()

def _build_eval_parser():
    """Return an argparse parser defining the evaluation options given to Evaluator."""
    parser = argparse.ArgumentParser(description="Evaluates.")
    parser.add_argument("--pretraining", "-p", action="store_true", help="force pretraining")
    parser.add_argument("--modelname", "-n", type=str, default="ADDA", help="model name")
    parser.add_argument("--milisi", "-m", action="store_false", help="no milisi")
    parser.add_argument("--config_fname", "-f", type=str, help="Name of the config file to use")
    parser.add_argument("--configs_dir", "-cdir", type=str, default="configs", help="config dir")
    parser.add_argument(
        "--njobs", type=int, default=1, help="Number of jobs to use for parallel processing."
    )
    parser.add_argument("--cuda", "-c", default=None, help="GPU index to use")
    parser.add_argument("--tmpdir", "-d", default=None, help="optional temporary results directory")
    parser.add_argument("--test", "-t", action="store_true", help="test mode")
    parser.add_argument(
        "--early_stopping",
        "-e",
        action="store_true",
        help="evaluate early stopping. Default: False",
    )
    parser.add_argument(
        "--reverse_val",
        "-r",
        action="store_true",
        # trailing space: the two literals are concatenated implicitly
        help="use best model through reverse validation. Will use provided "
        "config file to search across models, then use the one loaded. Default: False",
    )
    parser.add_argument("--model_dir", default="model", help="model directory")
    parser.add_argument("--results_dir", default="results", help="results directory")
    # NOTE(review): no type= on the two seed options, so values given on the
    # command line arrive as strings; Evaluator presumably converts them. TODO confirm.
    parser.add_argument(
        "--seed_override",
        default=None,
        help="seed to use for torch and numpy; overrides that in config file",
    )
    parser.add_argument(
        "--ps_seed",
        default=-1,
        help="specific pseudospot seed to use; default of -1 corresponds to 623",
    )
    return parser


# One-time setup: basicConfig only takes effect on its first call anyway, and
# the parser is identical for every seed pair.
logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.WARNING,
    format="%(asctime)s:%(levelname)s:%(name)s:%(message)s",
)
parser = _build_eval_parser()

# Evaluate each model trained above, pairing seeds exactly as in training.
for ps_seed, model_seed in zip(PS_SEEDS, MODEL_SEEDS):
    args = parser.parse_args([
        f"--modelname={MODEL_NAME}",
        f"--config_fname={CONFIG_FNAME}",
        "--njobs=8",
        "--test",
        f"--model_dir={MODEL_DIR}",
        "--results_dir=results_FINAL",
        f"--seed_override={model_seed}",
        f"--ps_seed={ps_seed}",
    ])

    script_start_time = datetime.datetime.now(datetime.timezone.utc)
    main(args)
    print("Script run time:", datetime.datetime.now(datetime.timezone.utc) - script_start_time)

Evaluating CellDART_original on with 8 jobs
Using library config:
None
Loading config celldart-final-pdac-ht.yml ... 
data_params:
  all_genes: false
  data_dir: ../AGrEDA/data
  dset: pdac
  n_markers: 80
  n_mix: 70
  n_spots: 100000
  one_model: true
  samp_split: false
  sc_id: CA001063
  scaler_name: minmax
  st_id: GSE111672
  st_split: false
lib_params:
  manual_seed: 2123428735
model_params:
  celldart_kwargs:
    bn_momentum: 0.01
    emb_dim: 32
  model_version: gen_pdac-18486
train_params:
  alpha: 0.6
  alpha_lr: 5
  batch_size: 128
  initial_train_epochs: 10
  lr: 0.01
  n_iter: 15000
  pretraining: true
  reverse_val: false

Saving results to results_FINAL/CellDART_original/pdac/CA001063_GSE111672/80markers/70mix_100000spots/minmax/gen_pdac-18486/2353 ...
Loading Data
Loading ST adata: 
Getting predictions: 


  updates=self.state_updates,


Plotting Samples
n_jobs_samples < 4, no parallelization
Calculating domain shift for pdac_a: TRAIN |

  updates=self.state_updates,


 milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for pdac_b: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 


  updates=self.state_updates,
  updates=self.state_updates,


                                          Pseudospots (Cosine Distance)  \
                                                                  train   
                       SC Split Sample ID                                 
Before DA                       pdac_a                         0.009479   
                                pdac_b                         0.009479   
After DA (final model)          pdac_a                         0.036197   
                                pdac_b                         0.036197   

                                                                   RF50  \
                                                val      test     train   
                       SC Split Sample ID                                 
Before DA                       pdac_a     0.010320  0.009933  0.999975   
                                pdac_b     0.010320  0.009933  0.999925   
After DA (final model)          pdac_a     0.036407  0.036144  0.999750   
                        



Loading ST adata: 
Getting predictions: 


  updates=self.state_updates,


Plotting Samples
n_jobs_samples < 4, no parallelization
Calculating domain shift for pdac_a: TRAIN |

  updates=self.state_updates,


 milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for pdac_b: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 


  updates=self.state_updates,
  updates=self.state_updates,


                                          Pseudospots (Cosine Distance)  \
                                                                  train   
                       SC Split Sample ID                                 
Before DA                       pdac_a                         0.011481   
                                pdac_b                         0.011481   
After DA (final model)          pdac_a                         0.080794   
                                pdac_b                         0.080794   

                                                                  RF50  \
                                                val      test    train   
                       SC Split Sample ID                                
Before DA                       pdac_a     0.012181  0.011649  1.00000   
                                pdac_b     0.012181  0.011649  1.00000   
After DA (final model)          pdac_a     0.078774  0.080421  1.00000   
                              



Loading ST adata: 
Getting predictions: 


  updates=self.state_updates,


Plotting Samples
n_jobs_samples < 4, no parallelization
Calculating domain shift for pdac_a: TRAIN |

  updates=self.state_updates,


 milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for pdac_b: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 


  updates=self.state_updates,
  updates=self.state_updates,


                                          Pseudospots (Cosine Distance)  \
                                                                  train   
                       SC Split Sample ID                                 
Before DA                       pdac_a                         0.009191   
                                pdac_b                         0.009191   
After DA (final model)          pdac_a                         0.037574   
                                pdac_b                         0.037574   

                                                                  RF50  \
                                                val      test    train   
                       SC Split Sample ID                                
Before DA                       pdac_a     0.009665  0.009520  0.99975   
                                pdac_b     0.009665  0.009520  0.99995   
After DA (final model)          pdac_a     0.038313  0.037222  1.00000   
                              



Loading ST adata: 
Getting predictions: 


  updates=self.state_updates,


Plotting Samples
n_jobs_samples < 4, no parallelization
Calculating domain shift for pdac_a: TRAIN |

  updates=self.state_updates,


 milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for pdac_b: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 


  updates=self.state_updates,
  updates=self.state_updates,


                                          Pseudospots (Cosine Distance)  \
                                                                  train   
                       SC Split Sample ID                                 
Before DA                       pdac_a                         0.007810   
                                pdac_b                         0.007810   
After DA (final model)          pdac_a                         0.045702   
                                pdac_b                         0.045702   

                                                                   RF50  \
                                                val      test     train   
                       SC Split Sample ID                                 
Before DA                       pdac_a     0.008821  0.008637  1.000000   
                                pdac_b     0.008821  0.008637  1.000000   
After DA (final model)          pdac_a     0.046648  0.046507  0.999975   
                        



Loading ST adata: 
Getting predictions: 


  updates=self.state_updates,


Plotting Samples
n_jobs_samples < 4, no parallelization
Calculating domain shift for pdac_a: TRAIN |

  updates=self.state_updates,


 milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for pdac_b: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 


  updates=self.state_updates,
  updates=self.state_updates,


                                          Pseudospots (Cosine Distance)  \
                                                                  train   
                       SC Split Sample ID                                 
Before DA                       pdac_a                         0.011106   
                                pdac_b                         0.011106   
After DA (final model)          pdac_a                         0.023756   
                                pdac_b                         0.023756   

                                                                   RF50  \
                                                val      test     train   
                       SC Split Sample ID                                 
Before DA                       pdac_a     0.011656  0.011221  0.999975   
                                pdac_b     0.011656  0.011221  1.000000   
After DA (final model)          pdac_a     0.024552  0.023906  0.999875   
                        

