# CellDART Example Code: mouse brain 
## (10x Visium of anterior mouse brain + scRNA-seq data of mouse brain)

In [1]:
import os

import numpy as np
import tensorflow as tf  # TensorFlow registers PluggableDevices here.
from tqdm.autonotebook import tqdm
import yaml

from CellDART import da_cellfraction
from src.da_utils import data_loading




  from tqdm.autonotebook import tqdm


In [2]:
tf.config.list_physical_devices()


[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
 PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [3]:
PS_SEEDS = (3679, 343, 25, 234, 98098)
MODEL_SEEDS = (2353, 24385, 284, 86322, 98237)
MODEL_DIR = "model_FINAL"
SEED_OVERRIDE = None

CONFIGS_DIR = "configs"
CONFIG_FNAME = "celldart-final-pdac-ht.yml"

# BOOTSTRAP = False
# BOOTSTRAP_ROUNDS = 10
# BOOTSTRAP_ALPHAS = [0.6, 1 / 0.6]

MODEL_NAME = "CellDART_original"


In [4]:
with open(os.path.join(CONFIGS_DIR, MODEL_NAME, CONFIG_FNAME), "r") as f:
    config = yaml.safe_load(f)

lib_params = config["lib_params"]
data_params = config["data_params"]
model_params = config["model_params"]
train_params = config["train_params"]

rewrite_config = False
if not "pretraining" in train_params:
    train_params["pretraining"] = True
    rewrite_config = True
if not "lr" in train_params:
    train_params["lr"] = 0.001
    rewrite_config = True

if rewrite_config:
    with open(os.path.join(CONFIGS_DIR, MODEL_NAME, CONFIG_FNAME), "w") as f:
        yaml.safe_dump(config, f)

tqdm.write(yaml.safe_dump(config))


data_params:
  all_genes: false
  data_dir: ../AGrEDA/data
  dset: pdac
  n_markers: 80
  n_mix: 50
  n_spots: 100000
  one_model: true
  samp_split: false
  sc_id: CA001063
  scaler_name: minmax
  st_id: GSE111672
  st_split: false
lib_params:
  manual_seed: 3195139925
model_params:
  celldart_kwargs:
    bn_momentum: 0.9
    emb_dim: 32
  model_version: gen_pdac-16798
train_params:
  alpha: 1.0
  alpha_lr: 5
  batch_size: 256
  initial_train_epochs: 10
  lr: 0.001
  n_iter: 15000
  pretraining: true
  reverse_val: false



## 1. Data load
### load scanpy data - 10x datasets

In [5]:
def train(ps_seed, model_seed):
    print(f"PS seed: {ps_seed}, model seed: {model_seed}")

    model_folder = data_loading.get_model_rel_path(
        MODEL_NAME,
        model_params["model_version"],
        lib_seed_path=str(model_seed),
        **data_params,
    )

    model_folder = os.path.join(MODEL_DIR, model_folder)
    if not os.path.isdir(model_folder):
        os.makedirs(model_folder)
        print(model_folder)
        
    selected_dir = data_loading.get_selected_dir(
        data_loading.get_dset_dir(
            data_params["data_dir"],
            dset=data_params.get("dset", "dlpfc"),
        ),
        **data_params,
    )
    # Load spatial data
    mat_sp_d, mat_sp_meta_d, st_sample_id_l = data_loading.load_spatial(
        selected_dir,
        **data_params,
    )

    # Load sc data
    sc_mix_d, lab_mix_d, sc_sub_dict, sc_sub_dict2 = data_loading.load_sc(
        selected_dir,
        **data_params,
        seed_int=ps_seed,
    )

    target_d = {}
    if "train" in mat_sp_d:
        # keys of dict are splits
        for split in mat_sp_d:
            target_d[split] = np.concatenate(list(mat_sp_d[split].values()))
    else:
        # keys of subdicts are splits
        for split in next(iter(mat_sp_d.values())):
            target_d[split] = np.concatenate((v[split] for v in mat_sp_d.values()))


    advtrain_folder = os.path.join(model_folder, "advtrain")
    pretrain_folder = os.path.join(model_folder, "pretrain")
    if not os.path.isdir(advtrain_folder):
        os.makedirs(advtrain_folder)
    if not os.path.isdir(pretrain_folder):
        os.makedirs(pretrain_folder)

    if data_params.get("samp_split"):
        tqdm.write(f"Adversarial training for slides {mat_sp_d['train'].keys()}: ")
        save_folder = os.path.join(advtrain_folder, "samp_split")
    else:
        tqdm.write(f"Adversarial training for slides {next(iter(mat_sp_d.keys()))}: ")
        save_folder = os.path.join(advtrain_folder, "one_model")

    if not os.path.isdir(save_folder):
        os.makedirs(save_folder)

    embs, embs_noda, clssmodel, clssmodel_noda = da_cellfraction.train(
        sc_mix_d["train"],
        lab_mix_d["train"],
        target_d["train"],
        alpha=train_params.get("alpha", 0.6),
        alpha_lr=train_params.get("alpha_lr", 5),
        emb_dim=model_params["celldart_kwargs"].get("emb_dim", 64),
        batch_size=train_params.get("batch_size", 512),
        n_iterations=train_params.get("n_iter", 3000),
        initial_train=train_params.get("pretraining", True),
        initial_train_epochs=train_params.get("initial_train_epochs", 10),
        batch_size_initial_train=max(train_params.get("batch_size", 512), 512),
        bn_momentum=1-config["model_params"].get("bn_momentum", 0.01),
        seed=model_seed,
    )

    # Save model

    if not os.path.isdir(os.path.join(save_folder, "final_model")):
        os.makedirs(os.path.join(save_folder, "final_model"))
    if not os.path.isdir(os.path.join(pretrain_folder, "final_model")):
        os.makedirs(os.path.join(pretrain_folder, "final_model"))

    clssmodel_noda.save(os.path.join(pretrain_folder, "final_model", "model"))
    clssmodel.save(os.path.join(save_folder, "final_model", "model"))

    embs_noda.save(os.path.join(pretrain_folder, "final_model", "embs"))
    embs.save(os.path.join(save_folder, "final_model", "embs"))

    with open(os.path.join(model_folder, "config.yml"), "w") as f:
        yaml.safe_dump(config, f)


for ps_seed, model_seed in zip(PS_SEEDS, MODEL_SEEDS):
    train(ps_seed, model_seed)

PS seed: 3679, model seed: 2353
Adversarial training for slides train: 


Train on 100000 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
initial_train_done


  updates = self.state_updates


0.0303476790869236
0.0303476790869236
Iteration 99, source loss =  1.851, discriminator acc = 0.998
Iteration 199, source loss =  0.543, discriminator acc = 0.986
Iteration 299, source loss =  0.276, discriminator acc = 1.000
Iteration 399, source loss =  0.293, discriminator acc = 0.999
Iteration 499, source loss =  0.443, discriminator acc = 1.000
Iteration 599, source loss =  0.215, discriminator acc = 0.965
Iteration 699, source loss =  0.287, discriminator acc = 1.000
Iteration 799, source loss =  0.322, discriminator acc = 0.411
Iteration 899, source loss =  0.389, discriminator acc = 0.667
Iteration 999, source loss =  0.122, discriminator acc = 0.957
Iteration 1099, source loss =  0.199, discriminator acc = 1.000
Iteration 1199, source loss =  0.157, discriminator acc = 0.924
Iteration 1299, source loss =  0.217, discriminator acc = 0.394
Iteration 1399, source loss =  0.246, discriminator acc = 0.006
Iteration 1499, source loss =  0.213, discriminator acc = 1.000
Iteration 159

  updates = self.state_updates


0.027807374956011773
0.027807374956011773
Iteration 99, source loss =  1.982, discriminator acc = 0.006
Iteration 199, source loss =  0.787, discriminator acc = 0.006
Iteration 299, source loss =  0.422, discriminator acc = 0.006
Iteration 399, source loss =  0.221, discriminator acc = 0.010
Iteration 499, source loss =  0.220, discriminator acc = 1.000
Iteration 599, source loss =  0.248, discriminator acc = 0.905
Iteration 699, source loss =  0.222, discriminator acc = 0.221
Iteration 799, source loss =  0.464, discriminator acc = 0.009
Iteration 899, source loss =  0.166, discriminator acc = 0.497
Iteration 999, source loss =  0.288, discriminator acc = 0.346
Iteration 1099, source loss =  0.148, discriminator acc = 0.006
Iteration 1199, source loss =  0.142, discriminator acc = 0.996
Iteration 1299, source loss =  0.233, discriminator acc = 0.022
Iteration 1399, source loss =  0.137, discriminator acc = 0.984
Iteration 1499, source loss =  0.182, discriminator acc = 0.463
Iteration

  updates = self.state_updates


0.02885524251639843
0.02885524251639843
Iteration 99, source loss =  2.198, discriminator acc = 0.006
Iteration 199, source loss =  0.865, discriminator acc = 0.678
Iteration 299, source loss =  0.456, discriminator acc = 1.000
Iteration 399, source loss =  0.528, discriminator acc = 0.006
Iteration 499, source loss =  0.214, discriminator acc = 0.006
Iteration 599, source loss =  0.136, discriminator acc = 0.006
Iteration 699, source loss =  0.216, discriminator acc = 0.006
Iteration 799, source loss =  0.154, discriminator acc = 0.006
Iteration 899, source loss =  0.227, discriminator acc = 0.006
Iteration 999, source loss =  0.190, discriminator acc = 0.006
Iteration 1099, source loss =  0.189, discriminator acc = 0.006
Iteration 1199, source loss =  0.133, discriminator acc = 0.006
Iteration 1299, source loss =  0.202, discriminator acc = 0.006
Iteration 1399, source loss =  0.124, discriminator acc = 0.007
Iteration 1499, source loss =  0.235, discriminator acc = 0.966
Iteration 1

  updates = self.state_updates


0.02694155962407589
0.02694155962407589
Iteration 99, source loss =  1.267, discriminator acc = 0.006
Iteration 199, source loss =  0.442, discriminator acc = 0.501
Iteration 299, source loss =  0.399, discriminator acc = 1.000
Iteration 399, source loss =  0.310, discriminator acc = 0.587
Iteration 499, source loss =  0.503, discriminator acc = 1.000
Iteration 599, source loss =  0.288, discriminator acc = 0.006
Iteration 699, source loss =  0.197, discriminator acc = 0.006
Iteration 799, source loss =  0.231, discriminator acc = 0.999
Iteration 899, source loss =  0.200, discriminator acc = 0.672
Iteration 999, source loss =  0.159, discriminator acc = 0.185
Iteration 1099, source loss =  0.196, discriminator acc = 0.006
Iteration 1199, source loss =  0.158, discriminator acc = 0.998
Iteration 1299, source loss =  0.126, discriminator acc = 0.008
Iteration 1399, source loss =  0.133, discriminator acc = 0.097
Iteration 1499, source loss =  0.164, discriminator acc = 0.064
Iteration 1

  updates = self.state_updates


0.02675807018876076
0.02675807018876076
Iteration 99, source loss =  1.344, discriminator acc = 0.006
Iteration 199, source loss =  0.348, discriminator acc = 0.945
Iteration 299, source loss =  0.523, discriminator acc = 1.000
Iteration 399, source loss =  0.198, discriminator acc = 0.990
Iteration 499, source loss =  0.615, discriminator acc = 1.000
Iteration 599, source loss =  0.223, discriminator acc = 1.000
Iteration 699, source loss =  0.279, discriminator acc = 0.053
Iteration 799, source loss =  0.205, discriminator acc = 0.103
Iteration 899, source loss =  0.152, discriminator acc = 0.006
Iteration 999, source loss =  0.167, discriminator acc = 0.871
Iteration 1099, source loss =  0.153, discriminator acc = 0.156
Iteration 1199, source loss =  0.212, discriminator acc = 0.355
Iteration 1299, source loss =  0.266, discriminator acc = 0.006
Iteration 1399, source loss =  0.188, discriminator acc = 1.000
Iteration 1499, source loss =  0.154, discriminator acc = 0.214
Iteration 1

## Eval


In [6]:
import argparse
import datetime
import logging

from src.da_models.model_utils.utils import get_metric_ctp
from src.da_utils.evaluator import Evaluator


metric_ctp = get_metric_ctp("cos")


def main(args):
    evaluator = Evaluator(vars(args), metric_ctp)
    evaluator.eval_spots()
    evaluator.evaluate_embeddings()
    evaluator.eval_sc()

    evaluator.produce_results()

for ps_seed, model_seed in zip(PS_SEEDS, MODEL_SEEDS):
    if MODEL_SEED != 
    parser = argparse.ArgumentParser(description="Evaluates.")
    parser.add_argument("--pretraining", "-p", action="store_true", help="force pretraining")
    parser.add_argument("--modelname", "-n", type=str, default="ADDA", help="model name")
    parser.add_argument("--milisi", "-m", action="store_false", help="no milisi")
    parser.add_argument("--config_fname", "-f", type=str, help="Name of the config file to use")
    parser.add_argument("--configs_dir", "-cdir", type=str, default="configs", help="config dir")
    parser.add_argument(
        "--njobs", type=int, default=1, help="Number of jobs to use for parallel processing."
    )
    parser.add_argument("--cuda", "-c", default=None, help="GPU index to use")
    parser.add_argument("--tmpdir", "-d", default=None, help="optional temporary results directory")
    parser.add_argument("--test", "-t", action="store_true", help="test mode")
    parser.add_argument(
        "--early_stopping",
        "-e",
        action="store_true",
        help="evaluate early stopping. Default: False",
    )
    parser.add_argument(
        "--reverse_val",
        "-r",
        action="store_true",
        help="use best model through reverse validation. Will use provided"
        "config file to search across models, then use the one loaded. Default: False",
    )
    parser.add_argument("--model_dir", default="model", help="model directory")
    parser.add_argument("--results_dir", default="results", help="results directory")
    parser.add_argument(
        "--seed_override",
        default=None,
        help="seed to use for torch and numpy; overrides that in config file",
    )
    parser.add_argument(
        "--ps_seed",
        default=-1,
        help="specific pseudospot seed to use; default of -1 corresponds to 623",
    )

    args = parser.parse_args([
        f"--modelname={MODEL_NAME}",
        f"--config_fname={CONFIG_FNAME}",
        "--njobs=8",
        "--test",
        f"--model_dir={MODEL_DIR}",
        "--results_dir=results_FINAL",
        f"--seed_override={model_seed}",
        f"--ps_seed={ps_seed}"
    ])

    script_start_time = datetime.datetime.now(datetime.timezone.utc)
    logger = logging.getLogger(__name__)
    logging.basicConfig(
        level=logging.WARNING,
        format="%(asctime)s:%(levelname)s:%(name)s:%(message)s",
    )
    main(args)
    print("Script run time:", datetime.datetime.now(datetime.timezone.utc) - script_start_time)

Evaluating CellDART_original on with 8 jobs
Using library config:
None
Loading config celldart-final-pdac-ht.yml ... 
data_params:
  all_genes: false
  data_dir: ../AGrEDA/data
  dset: pdac
  n_markers: 80
  n_mix: 50
  n_spots: 100000
  one_model: true
  samp_split: false
  sc_id: CA001063
  scaler_name: minmax
  st_id: GSE111672
  st_split: false
lib_params:
  manual_seed: 3195139925
model_params:
  celldart_kwargs:
    bn_momentum: 0.9
    emb_dim: 32
  model_version: gen_pdac-16798
train_params:
  alpha: 1.0
  alpha_lr: 5
  batch_size: 256
  initial_train_epochs: 10
  lr: 0.001
  n_iter: 15000
  pretraining: true
  reverse_val: false

Saving results to results_FINAL/CellDART_original/pdac/CA001063_GSE111672/80markers/50mix_100000spots/minmax/gen_pdac-16798/2353 ...
Loading Data
Loading ST adata: 
Getting predictions: 


  updates=self.state_updates,


Plotting Samples
n_jobs_samples < 4, no parallelization
Calculating domain shift for pdac_a: TRAIN |

  updates=self.state_updates,


 milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for pdac_b: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 


  updates=self.state_updates,
  updates=self.state_updates,


                                          Pseudospots (Cosine Distance)  \
                                                                  train   
                       SC Split Sample ID                                 
Before DA                       pdac_a                         0.012334   
                                pdac_b                         0.012334   
After DA (final model)          pdac_a                         0.077407   
                                pdac_b                         0.077407   

                                                                  RF50  \
                                                val      test    train   
                       SC Split Sample ID                                
Before DA                       pdac_a     0.013174  0.013555  0.99990   
                                pdac_b     0.013174  0.013555  1.00000   
After DA (final model)          pdac_a     0.078702  0.075747  0.99995   
                              



Loading ST adata: 
Getting predictions: 


  updates=self.state_updates,


Plotting Samples
n_jobs_samples < 4, no parallelization
Calculating domain shift for pdac_a: TRAIN |

  updates=self.state_updates,


 milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for pdac_b: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 


  updates=self.state_updates,
  updates=self.state_updates,


                                          Pseudospots (Cosine Distance)  \
                                                                  train   
                       SC Split Sample ID                                 
Before DA                       pdac_a                         0.010409   
                                pdac_b                         0.010409   
After DA (final model)          pdac_a                         0.066852   
                                pdac_b                         0.066852   

                                                                   RF50  \
                                                val      test     train   
                       SC Split Sample ID                                 
Before DA                       pdac_a     0.011425  0.011368  1.000000   
                                pdac_b     0.011425  0.011368  1.000000   
After DA (final model)          pdac_a     0.066778  0.067337  0.999925   
                        



Loading ST adata: 
Getting predictions: 


  updates=self.state_updates,


Plotting Samples
n_jobs_samples < 4, no parallelization
Calculating domain shift for pdac_a: TRAIN |

  updates=self.state_updates,


 milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for pdac_b: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 


  updates=self.state_updates,
  updates=self.state_updates,


                                          Pseudospots (Cosine Distance)  \
                                                                  train   
                       SC Split Sample ID                                 
Before DA                       pdac_a                         0.009832   
                                pdac_b                         0.009832   
After DA (final model)          pdac_a                         0.044209   
                                pdac_b                         0.044209   

                                                               RF50       \
                                                val      test train  val   
                       SC Split Sample ID                                  
Before DA                       pdac_a     0.010825  0.010928   1.0  1.0   
                                pdac_b     0.010825  0.010928   1.0  1.0   
After DA (final model)          pdac_a     0.044706  0.044132   1.0  1.0   
                  



Loading ST adata: 
Getting predictions: 


  updates=self.state_updates,


Plotting Samples
n_jobs_samples < 4, no parallelization
Calculating domain shift for pdac_a: TRAIN |

  updates=self.state_updates,


 milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for pdac_b: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 


  updates=self.state_updates,
  updates=self.state_updates,


                                          Pseudospots (Cosine Distance)  \
                                                                  train   
                       SC Split Sample ID                                 
Before DA                       pdac_a                         0.010234   
                                pdac_b                         0.010234   
After DA (final model)          pdac_a                         0.065965   
                                pdac_b                         0.065965   

                                                                   RF50       \
                                                val      test     train  val   
                       SC Split Sample ID                                      
Before DA                       pdac_a     0.011132  0.011099  1.000000  1.0   
                                pdac_b     0.011132  0.011099  1.000000  1.0   
After DA (final model)          pdac_a     0.067465  0.064682  0.999975  1



Loading ST adata: 
Getting predictions: 


  updates=self.state_updates,


Plotting Samples
n_jobs_samples < 4, no parallelization
Calculating domain shift for pdac_a: TRAIN |

  updates=self.state_updates,


 milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for pdac_b: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 


  updates=self.state_updates,
  updates=self.state_updates,


                                          Pseudospots (Cosine Distance)  \
                                                                  train   
                       SC Split Sample ID                                 
Before DA                       pdac_a                         0.009917   
                                pdac_b                         0.009917   
After DA (final model)          pdac_a                         0.052353   
                                pdac_b                         0.052353   

                                                                   RF50  \
                                                val      test     train   
                       SC Split Sample ID                                 
Before DA                       pdac_a     0.011289  0.011472  0.999975   
                                pdac_b     0.011289  0.011472  0.999975   
After DA (final model)          pdac_a     0.052857  0.052771  0.999225   
                        



: 