# CellDART Example Code: human DLPFC
## (10x Visium of human dorsolateral prefrontal cortex (spatialLIBD) + scRNA-seq data (GSE144136))

In [1]:
import os
import warnings

import numpy as np
import tensorflow as tf  # TensorFlow registers PluggableDevices here.
from tqdm.autonotebook import tqdm
import yaml

from CellDART import da_cellfraction
from src.da_utils import data_loading




  from tqdm.autonotebook import tqdm


In [2]:
# Let TensorFlow allocate GPU memory on demand instead of its default of mapping
# (nearly) all device memory up front. To cap usage instead, configure each GPU with
# tf.config.set_logical_device_configuration(
#     gpu, [tf.config.LogicalDeviceConfiguration(memory_limit=8192)])
gpus = tf.config.list_physical_devices("GPU")
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

In [3]:
# --- Seeds ---------------------------------------------------------------
# Training runs once per (PS_SEEDS[i], MODEL_SEEDS[i]) pair: the pseudospot seed
# drives pseudospot generation, the model seed drives model training.
PS_SEEDS = (3679, 343, 25, 234, 98098)
MODEL_SEEDS = (2353, 24385, 284, 86322, 98237)
# NOTE(review): not referenced by the cells below — confirm it is still needed.
SEED_OVERRIDE = None

# --- Model / config selection ----------------------------------------------
MODEL_NAME = "CellDART_original"
CONFIGS_DIR = "configs"
CONFIG_FNAME = "celldart-final-dlpfc-ht.yml"

# --- Output ---------------------------------------------------------------
MODEL_DIR = "model_FINAL"

# Bootstrapping is disabled for this run (earlier sketch: BOOTSTRAP = False,
# BOOTSTRAP_ROUNDS = 10, BOOTSTRAP_ALPHAS = [0.6, 1 / 0.6]).


In [4]:
# Load the experiment config and split it into its parameter groups.
config_path = os.path.join(CONFIGS_DIR, MODEL_NAME, CONFIG_FNAME)
with open(config_path, "r") as f:
    config = yaml.safe_load(f)

lib_params = config["lib_params"]
data_params = config["data_params"]
model_params = config["model_params"]
train_params = config["train_params"]

# Older config files may lack these training options: fill in defaults and write
# the completed config back so the file on disk matches what is used for training.
_TRAIN_PARAM_DEFAULTS = {"pretraining": True, "lr": 0.001}
rewrite_config = False
for key, default in _TRAIN_PARAM_DEFAULTS.items():
    if key not in train_params:
        train_params[key] = default
        rewrite_config = True

if rewrite_config:
    with open(config_path, "w") as f:
        yaml.safe_dump(config, f)

tqdm.write(yaml.safe_dump(config))


data_params:
  all_genes: false
  data_dir: ../AGrEDA/data
  dset: dlpfc
  n_markers: 40
  n_mix: 3
  n_spots: 100000
  samp_split: true
  sc_id: GSE144136
  scaler_name: minmax
  st_id: spatialLIBD
  st_split: false
lib_params:
  manual_seed: 1846326316
model_params:
  celldart_kwargs:
    bn_momentum: 0.01
    emb_dim: 32
  model_version: gen_dlpfc_dlpfc-9356
train_params:
  alpha: 2.0
  alpha_lr: 10
  batch_size: 512
  initial_train_epochs: 10
  lr: 0.0001
  n_iter: 15000
  pretraining: true
  reverse_val: false



## 1. Data loading and training
### Load the 10x Visium spatial data and scRNA-seq pseudospots, then train one model per seed pair

In [5]:
def train(ps_seed, model_seed):
    """Train and save one CellDART model for a (pseudospot seed, model seed) pair.

    Loads the spatial and single-cell data described by the module-level
    ``data_params``, trains with ``da_cellfraction.train`` using ``train_params`` and
    ``model_params``, and saves under ``MODEL_DIR/<model path>/``:

    * ``pretrain/final_model/{model,embs}`` - the non-adversarial models
      (``clssmodel_noda``, ``embs_noda``);
    * ``advtrain/<samp_split|one_model>/final_model/{model,embs}`` - the
      adversarially trained models (``clssmodel``, ``embs``);
    * ``config.yml`` - the config used for this run.

    Args:
        ps_seed: Pseudospot seed, passed to ``data_loading.load_sc`` as ``seed_int``.
        model_seed: Model seed, passed to ``da_cellfraction.train`` as ``seed`` and
            used as the seed component of the model directory path.
    """
    print(f"PS seed: {ps_seed}, model seed: {model_seed}")

    model_folder = os.path.join(
        MODEL_DIR,
        data_loading.get_model_rel_path(
            MODEL_NAME,
            model_params["model_version"],
            lib_seed_path=str(model_seed),
            **data_params,
        ),
    )
    if not os.path.isdir(model_folder):
        os.makedirs(model_folder, exist_ok=True)
        print(model_folder)

    selected_dir = data_loading.get_selected_dir(
        data_loading.get_dset_dir(
            data_params["data_dir"],
            dset=data_params.get("dset", "dlpfc"),
        ),
        **data_params,
    )
    # Load spatial data
    mat_sp_d, mat_sp_meta_d, st_sample_id_l = data_loading.load_spatial(
        selected_dir,
        **data_params,
    )

    # Load sc data
    sc_mix_d, lab_mix_d, sc_sub_dict, sc_sub_dict2 = data_loading.load_sc(
        selected_dir,
        **data_params,
        seed_int=ps_seed,
    )

    # Stack the spots of all slides into one target array per split.
    if "train" in mat_sp_d:
        # Top-level keys are splits: {split: {sample_id: array}}.
        target_d = {
            split: np.concatenate(list(per_sample.values()))
            for split, per_sample in mat_sp_d.items()
        }
    else:
        # Subdict keys are splits: {key (presumably sample id): {split: array}}.
        # np.concatenate needs a sequence; a bare generator raises TypeError.
        first_entry = next(iter(mat_sp_d.values()))
        target_d = {
            split: np.concatenate([per_split[split] for per_split in mat_sp_d.values()])
            for split in first_entry
        }

    advtrain_folder = os.path.join(model_folder, "advtrain")
    pretrain_folder = os.path.join(model_folder, "pretrain")
    os.makedirs(advtrain_folder, exist_ok=True)
    os.makedirs(pretrain_folder, exist_ok=True)

    if data_params.get("samp_split"):
        tqdm.write(f"Adversarial training for slides {mat_sp_d['train'].keys()}: ")
        save_folder = os.path.join(advtrain_folder, "samp_split")
    else:
        tqdm.write(f"Adversarial training for slides {next(iter(mat_sp_d.keys()))}: ")
        save_folder = os.path.join(advtrain_folder, "one_model")
    os.makedirs(save_folder, exist_ok=True)

    # NOTE(review): train_params["lr"] and model_params["celldart_kwargs"]["bn_momentum"]
    # are present in the config but are not passed here - confirm whether
    # da_cellfraction.train should receive them.
    embs, embs_noda, clssmodel, clssmodel_noda = da_cellfraction.train(
        sc_mix_d["train"],
        lab_mix_d["train"],
        target_d["train"],
        alpha=train_params.get("alpha", 0.6),
        alpha_lr=train_params.get("alpha_lr", 5),
        emb_dim=model_params["celldart_kwargs"].get("emb_dim", 64),
        batch_size=train_params.get("batch_size", 512),
        n_iterations=train_params.get("n_iter", 3000),
        initial_train=train_params.get("pretraining", True),
        initial_train_epochs=train_params.get("initial_train_epochs", 10),
        # Pretraining uses the configured batch size, but never less than 512.
        batch_size_initial_train=max(train_params.get("batch_size", 512), 512),
        seed=model_seed,
    )

    # Save model
    advtrain_final = os.path.join(save_folder, "final_model")
    pretrain_final = os.path.join(pretrain_folder, "final_model")
    os.makedirs(advtrain_final, exist_ok=True)
    os.makedirs(pretrain_final, exist_ok=True)

    clssmodel_noda.save(os.path.join(pretrain_final, "model"))
    clssmodel.save(os.path.join(advtrain_final, "model"))

    embs_noda.save(os.path.join(pretrain_final, "embs"))
    embs.save(os.path.join(advtrain_final, "embs"))

    with open(os.path.join(model_folder, "config.yml"), "w") as f:
        yaml.safe_dump(config, f)


# Train one model per (pseudospot seed, model seed) pair.
seed_pairs = list(zip(PS_SEEDS, MODEL_SEEDS))
for ps_seed, model_seed in seed_pairs:
    train(ps_seed=ps_seed, model_seed=model_seed)

PS seed: 3679, model seed: 2353
Adversarial training for slides dict_keys(['151507', '151508', '151509', '151510', '151669', '151670', '151671', '151672', '151674', '151676']): 
Train on 100000 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
initial_train_done


  updates = self.state_updates


0.38906321561336515
0.38906321561336515
Iteration 99, source loss =  3.235, discriminator acc = 0.990
Iteration 199, source loss =  3.526, discriminator acc = 0.759
Iteration 299, source loss =  2.094, discriminator acc = 0.789
Iteration 399, source loss =  1.855, discriminator acc = 0.885
Iteration 499, source loss =  1.622, discriminator acc = 0.911
Iteration 599, source loss =  1.307, discriminator acc = 0.888
Iteration 699, source loss =  1.274, discriminator acc = 0.649
Iteration 799, source loss =  1.203, discriminator acc = 0.923
Iteration 899, source loss =  1.352, discriminator acc = 0.696
Iteration 999, source loss =  1.354, discriminator acc = 0.770
Iteration 1099, source loss =  2.369, discriminator acc = 0.525
Iteration 1199, source loss =  1.651, discriminator acc = 0.743
Iteration 1299, source loss =  2.315, discriminator acc = 0.899
Iteration 1399, source loss =  1.860, discriminator acc = 0.966
Iteration 1499, source loss =  1.828, discriminator acc = 0.964
Iteration 1

  updates = self.state_updates


0.42610923714637755
0.42610923714637755
Iteration 99, source loss =  4.460, discriminator acc = 0.763
Iteration 199, source loss =  3.640, discriminator acc = 0.989
Iteration 299, source loss =  2.021, discriminator acc = 0.971
Iteration 399, source loss =  1.832, discriminator acc = 0.414
Iteration 499, source loss =  1.908, discriminator acc = 0.943
Iteration 599, source loss =  1.840, discriminator acc = 0.690
Iteration 699, source loss =  1.792, discriminator acc = 0.961
Iteration 799, source loss =  1.839, discriminator acc = 0.828
Iteration 899, source loss =  1.806, discriminator acc = 0.914
Iteration 999, source loss =  1.495, discriminator acc = 0.937
Iteration 1099, source loss =  1.408, discriminator acc = 0.978
Iteration 1199, source loss =  1.547, discriminator acc = 0.951
Iteration 1299, source loss =  1.454, discriminator acc = 0.952
Iteration 1399, source loss =  1.459, discriminator acc = 0.438
Iteration 1499, source loss =  1.428, discriminator acc = 0.682
Iteration 1

  updates = self.state_updates


0.4182668998718262
0.4182668998718262
Iteration 99, source loss =  4.476, discriminator acc = 0.904
Iteration 199, source loss =  3.855, discriminator acc = 0.915
Iteration 299, source loss =  3.103, discriminator acc = 0.956
Iteration 399, source loss =  2.381, discriminator acc = 0.971
Iteration 499, source loss =  1.750, discriminator acc = 0.913
Iteration 599, source loss =  1.882, discriminator acc = 0.971
Iteration 699, source loss =  1.830, discriminator acc = 0.598
Iteration 799, source loss =  1.527, discriminator acc = 0.330
Iteration 899, source loss =  1.563, discriminator acc = 0.957
Iteration 999, source loss =  1.634, discriminator acc = 0.841
Iteration 1099, source loss =  1.763, discriminator acc = 0.456
Iteration 1199, source loss =  1.630, discriminator acc = 0.733
Iteration 1299, source loss =  2.009, discriminator acc = 0.988
Iteration 1399, source loss =  2.047, discriminator acc = 0.911
Iteration 1499, source loss =  1.693, discriminator acc = 0.904
Iteration 159

  updates = self.state_updates


0.3818477215671539
0.3818477215671539
Iteration 99, source loss =  3.642, discriminator acc = 0.993
Iteration 199, source loss =  3.271, discriminator acc = 0.741
Iteration 299, source loss =  2.383, discriminator acc = 0.576
Iteration 399, source loss =  2.077, discriminator acc = 0.997
Iteration 499, source loss =  1.914, discriminator acc = 0.492
Iteration 599, source loss =  1.490, discriminator acc = 0.880
Iteration 699, source loss =  1.610, discriminator acc = 0.813
Iteration 799, source loss =  1.528, discriminator acc = 0.621
Iteration 899, source loss =  1.287, discriminator acc = 0.708
Iteration 999, source loss =  1.160, discriminator acc = 0.882
Iteration 1099, source loss =  1.317, discriminator acc = 0.423
Iteration 1199, source loss =  1.913, discriminator acc = 0.773
Iteration 1299, source loss =  2.125, discriminator acc = 0.318
Iteration 1399, source loss =  1.518, discriminator acc = 0.588
Iteration 1499, source loss =  1.465, discriminator acc = 0.601
Iteration 159

  updates = self.state_updates


0.3525105092096329
0.3525105092096329
Iteration 99, source loss =  6.861, discriminator acc = 0.800
Iteration 199, source loss =  2.939, discriminator acc = 0.973
Iteration 299, source loss =  2.650, discriminator acc = 0.370
Iteration 399, source loss =  1.683, discriminator acc = 0.769
Iteration 499, source loss =  1.682, discriminator acc = 0.891
Iteration 599, source loss =  1.721, discriminator acc = 0.806
Iteration 699, source loss =  1.492, discriminator acc = 0.709
Iteration 799, source loss =  1.408, discriminator acc = 0.817
Iteration 899, source loss =  1.509, discriminator acc = 0.336
Iteration 999, source loss =  1.416, discriminator acc = 0.292
Iteration 1099, source loss =  1.432, discriminator acc = 0.758
Iteration 1199, source loss =  1.381, discriminator acc = 0.970
Iteration 1299, source loss =  1.434, discriminator acc = 0.822
Iteration 1399, source loss =  1.694, discriminator acc = 0.983
Iteration 1499, source loss =  1.549, discriminator acc = 0.994
Iteration 159

## 2. Evaluation


In [6]:
import argparse
import datetime
import logging

from src.da_models.model_utils.utils import get_metric_ctp
from src.da_utils.evaluator import Evaluator


# Metric handed to the Evaluator in main(); "ctp" presumably means cell-type
# proportions. "cos" presumably selects cosine distance - the evaluation results
# table is headed "Pseudospots (Cosine Distance)".
metric_ctp = get_metric_ctp("cos")

def _device_to_str(device):
    return device.name.lstrip("/physical_device:")

def get_device(cuda_index=None):
    """Pick a device string (e.g. "GPU:0" or "CPU:0") for use with ``tf.device``.

    Args:
        cuda_index: GPU ordinal as an int or numeric string (argparse passes a
            string). ``None`` selects the first GPU; a negative value selects the CPU.

    Returns:
        The chosen device name without the "/physical_device:" prefix.

    When the requested ordinal is out of range, warns and uses the first GPU. When no
    GPU is visible (including the ``None`` case, which previously raised IndexError),
    warns and uses the first CPU.
    """
    gpus = tf.config.list_physical_devices("GPU")

    if cuda_index is not None:
        cuda_index = int(cuda_index)
        if cuda_index < 0:
            return _device_to_str(tf.config.list_physical_devices("CPU")[0])
        if cuda_index < len(gpus):
            return _device_to_str(gpus[cuda_index])
        if gpus:
            warnings.warn("GPU ordinal not valid; using default", category=UserWarning, stacklevel=2)
            return _device_to_str(gpus[0])
    elif gpus:
        return _device_to_str(gpus[0])

    warnings.warn("Using CPU", category=UserWarning, stacklevel=2)
    return _device_to_str(tf.config.list_physical_devices("CPU")[0])

def main(args):
    """Run the full evaluation for one parsed argument namespace.

    Builds an ``Evaluator`` from ``vars(args)`` and the module-level ``metric_ctp``,
    then runs its stages in order: spots, embeddings, single cells, and finally
    result production.
    """
    arg_dict = vars(args)
    evaluator = Evaluator(arg_dict, metric_ctp)
    # Each stage is looked up and invoked in sequence.
    for stage in ("eval_spots", "evaluate_embeddings", "eval_sc", "produce_results"):
        getattr(evaluator, stage)()

# Build the evaluation CLI parser once. The notebook feeds it a fixed argument list
# per seed pair instead of reading sys.argv.
parser = argparse.ArgumentParser(description="Evaluates.")
parser.add_argument("--pretraining", "-p", action="store_true", help="force pretraining")
parser.add_argument("--modelname", "-n", type=str, default="ADDA", help="model name")
parser.add_argument("--milisi", "-m", action="store_false", help="no milisi")
parser.add_argument("--config_fname", "-f", type=str, help="Name of the config file to use")
parser.add_argument("--configs_dir", "-cdir", type=str, default="configs", help="config dir")
parser.add_argument(
    "--njobs", type=int, default=1, help="Number of jobs to use for parallel processing."
)
parser.add_argument("--cuda", "-c", default=None, help="GPU index to use")
parser.add_argument("--tmpdir", "-d", default=None, help="optional temporary results directory")
parser.add_argument("--test", "-t", action="store_true", help="test mode")
parser.add_argument(
    "--early_stopping",
    "-e",
    action="store_true",
    help="evaluate early stopping. Default: False",
)
parser.add_argument(
    "--reverse_val",
    "-r",
    action="store_true",
    help="use best model through reverse validation. Will use provided "
    "config file to search across models, then use the one loaded. Default: False",
)
parser.add_argument("--model_dir", default="model", help="model directory")
parser.add_argument("--results_dir", default="results", help="results directory")
parser.add_argument(
    "--seed_override",
    default=None,
    help="seed to use for torch and numpy; overrides that in config file",
)
parser.add_argument(
    "--ps_seed",
    default=-1,
    help="specific pseudospot seed to use; default of -1 corresponds to 623",
)

# Configure logging once; basicConfig is a no-op after the first call anyway.
logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.WARNING,
    format="%(asctime)s:%(levelname)s:%(name)s:%(message)s",
)

# Only the last model seed is evaluated.
# NOTE(review): looks like a filter left over from a partial re-run - confirm it is
# intended; set this to MODEL_SEEDS to evaluate every trained model.
EVAL_MODEL_SEEDS = MODEL_SEEDS[-1:]

for ps_seed, model_seed in zip(PS_SEEDS, MODEL_SEEDS):
    if model_seed not in EVAL_MODEL_SEEDS:
        continue

    args = parser.parse_args(
        [
            f"--modelname={MODEL_NAME}",
            f"--config_fname={CONFIG_FNAME}",
            f"--configs_dir={CONFIGS_DIR}",
            "--njobs=1",
            "--test",
            f"--model_dir={MODEL_DIR}",
            "--results_dir=results_FINAL",
            f"--seed_override={model_seed}",
            f"--ps_seed={ps_seed}",
            "--cuda=0",
        ]
    )

    script_start_time = datetime.datetime.now(datetime.timezone.utc)
    with tf.device(get_device(args.cuda)):
        # Smoke test: run a tiny matmul under the selected device and print it; the
        # tensor's repr shows which device the op was placed on.
        a = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
        b = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
        c = tf.matmul(a, b)

        print(c)

        main(args)
    print("Script run time:", datetime.datetime.now(datetime.timezone.utc) - script_start_time)

Tensor("MatMul_1:0", shape=(2, 2), dtype=float32, device=/device:GPU:0)
Evaluating CellDART_original on with 1 jobs
Using library config:
None
Loading config celldart-final-dlpfc-ht.yml ... 
data_params:
  all_genes: false
  data_dir: ../AGrEDA/data
  dset: dlpfc
  n_markers: 40
  n_mix: 3
  n_spots: 100000
  samp_split: true
  sc_id: GSE144136
  scaler_name: minmax
  st_id: spatialLIBD
  st_split: false
lib_params:
  manual_seed: 1846326316
model_params:
  celldart_kwargs:
    bn_momentum: 0.01
    emb_dim: 32
  model_version: gen_dlpfc_dlpfc-9356
train_params:
  alpha: 2.0
  alpha_lr: 10
  batch_size: 512
  initial_train_epochs: 10
  lr: 0.0001
  n_iter: 15000
  pretraining: true
  reverse_val: false

Saving results to results_FINAL/CellDART_original/dlpfc/GSE144136_spatialLIBD/40markers/3mix_100000spots/minmax/gen_dlpfc_dlpfc-9356/98237 ...
Loading Data


  self.results_folder = self.temp_folder_holder.set_output_folder(


Loading ST adata: 
Getting predictions: 


  updates=self.state_updates,


Plotting Samples
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    6.5s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:   12.9s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:   19.9s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:   26.5s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   5 out of   5 | elapsed:   32.4s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   6 out of   6 | elapsed:   38.2s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   7 out of   7 | elapsed:   44.6s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   8 out of   8 | elapsed:   51.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   9 out of   9 | elapsed:   56.9s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  10 out of  10 | elapsed:  1.1min remaining:    0.0s
[Parallel(n_jobs=1)]: Done  11 out of  11 | elapsed:  1.2min remaining:    0.0s
[Parallel(n_jobs=1)]: 

  updates=self.state_updates,


 milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151508: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151509: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151510: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151669: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151670: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151671: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151672: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151674: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151676: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 15167

  updates=self.state_updates,
  updates=self.state_updates,


                                          Pseudospots (Cosine Distance)  \
                                                                  train   
                       SC Split Sample ID                                 
Before DA              train    151507                         0.094176   
                                151508                         0.094176   
                                151509                         0.094176   
                                151510                         0.094176   
                                151669                         0.094176   
                                151670                         0.094176   
                                151671                         0.094176   
                                151672                         0.094176   
                                151674                         0.094176   
                                151676                         0.094176   
                       va

