# CellDART Example Code: human DLPFC
## (10x Visium of human dorsolateral prefrontal cortex (spatialLIBD) + scRNA-seq data of human DLPFC (GSE144136))

In [1]:
import os
import warnings

import numpy as np
import tensorflow as tf  # TensorFlow registers PluggableDevices here.
from tqdm.autonotebook import tqdm
import yaml

from CellDART import da_cellfraction
from src.da_utils import data_loading




  from tqdm.autonotebook import tqdm


In [2]:
# Show every physical device TensorFlow can see (all device types).
tf.config.list_physical_devices(device_type=None)


[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
 PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [3]:
# Seed pairs: the i-th pseudospot seed is zipped with the i-th model seed.
PS_SEEDS = (3679, 343, 25, 234, 98098)
MODEL_SEEDS = (2353, 24385, 284, 86322, 98237)
SEED_OVERRIDE = None

# Model identity, output location and the config file to load.
MODEL_NAME = "CellDART_original"
MODEL_DIR = "model_FINAL"
CONFIGS_DIR = "configs"
CONFIG_FNAME = "basic_config.yml"


In [4]:
# Load the experiment config for MODEL_NAME, fill in training defaults that
# older config files may lack, and write the config back if anything was added.
config_path = os.path.join(CONFIGS_DIR, MODEL_NAME, CONFIG_FNAME)
with open(config_path, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

lib_params = config["lib_params"]
data_params = config["data_params"]
model_params = config["model_params"]
train_params = config["train_params"]

# Values injected into train_params only when the key is absent.
TRAIN_PARAM_DEFAULTS = {"pretraining": True, "lr": 0.001}

rewrite_config = False
for key, default in TRAIN_PARAM_DEFAULTS.items():
    if key not in train_params:
        train_params[key] = default
        rewrite_config = True

# Persist injected defaults so the on-disk config reflects what was used.
if rewrite_config:
    with open(config_path, "w", encoding="utf-8") as f:
        yaml.safe_dump(config, f)

tqdm.write(yaml.safe_dump(config))


data_params:
  all_genes: false
  data_dir: ../AGrEDA/data
  dset: dlpfc
  n_markers: 20
  n_mix: 8
  n_spots: 100000
  samp_split: true
  sc_id: GSE144136
  scaler_name: minmax
  st_id: spatialLIBD
  st_split: false
lib_params: {}
model_params:
  celldart_kwargs:
    bn_momentum: 0.01
    emb_dim: 64
  model_version: std
train_params:
  alpha: 0.6
  alpha_lr: 5
  batch_size: 512
  initial_train_epochs: 10
  lr: 0.001
  n_iter: 15000
  pretraining: true
  reverse_val: true



## 1. Data loading and training
### Load the scanpy data (10x datasets) and train one CellDART model per seed pair

In [5]:
def _stack_target_splits(mat_sp_d):
    """Concatenate the spatial (target-domain) matrices of all samples, per split.

    ``mat_sp_d`` is accepted in either of two layouts:

    * ``{split: {sample_id: array}}`` -- detected by a top-level ``"train"`` key;
    * ``{sample_id: {split: array}}`` -- every other case; the split names are
      taken from the first sample's sub-dict.

    Returns:
        dict mapping each split name to ``np.concatenate`` (axis 0) of that
        split's arrays across samples.

    Raises:
        ValueError: if ``mat_sp_d`` is empty.
    """
    if "train" in mat_sp_d:
        # Keys of the top-level dict are the splits.
        return {
            split: np.concatenate(list(by_sample.values()))
            for split, by_sample in mat_sp_d.items()
        }

    if not mat_sp_d:
        raise ValueError("mat_sp_d is empty; no spatial samples were loaded")

    # Keys of each per-sample sub-dict are the splits. np.concatenate requires a
    # sequence (a generator raises TypeError), hence the list comprehension.
    splits = next(iter(mat_sp_d.values()))
    return {
        split: np.concatenate([by_split[split] for by_split in mat_sp_d.values()])
        for split in splits
    }


def train(ps_seed, model_seed):
    """Train one CellDART model pair for a seed combination and save it.

    Args:
        ps_seed: Passed to ``data_loading.load_sc`` as ``seed_int``.
        model_seed: Passed to ``da_cellfraction.train`` as ``seed`` and used as
            the ``lib_seed_path`` component of the model output folder.

    Side effects:
        Creates the model folder tree under ``MODEL_DIR``. It saves the model and
        embedding model returned without the ``_noda`` suffix under
        ``advtrain/<samp_split|one_model>/final_model``. It saves the ``_noda``
        ones under ``pretrain/final_model``. It also writes ``config`` to
        ``config.yml`` in the model folder.

    Reads the notebook globals ``MODEL_NAME``, ``MODEL_DIR``, ``config``,
    ``data_params``, ``model_params`` and ``train_params``.
    """
    print(f"PS seed: {ps_seed}, model seed: {model_seed}")

    model_folder = os.path.join(
        MODEL_DIR,
        data_loading.get_model_rel_path(
            MODEL_NAME,
            model_params["model_version"],
            lib_seed_path=str(model_seed),
            **data_params,
        ),
    )
    if not os.path.isdir(model_folder):
        os.makedirs(model_folder)
        print(model_folder)  # announce only freshly created model folders

    selected_dir = data_loading.get_selected_dir(
        data_loading.get_dset_dir(
            data_params["data_dir"],
            dset=data_params.get("dset", "dlpfc"),
        ),
        **data_params,
    )
    # Load spatial data.
    mat_sp_d, mat_sp_meta_d, st_sample_id_l = data_loading.load_spatial(
        selected_dir,
        **data_params,
    )

    # Load sc data (mixtures + labels), seeded by ps_seed.
    sc_mix_d, lab_mix_d, sc_sub_dict, sc_sub_dict2 = data_loading.load_sc(
        selected_dir,
        **data_params,
        seed_int=ps_seed,
    )

    target_d = _stack_target_splits(mat_sp_d)

    advtrain_folder = os.path.join(model_folder, "advtrain")
    pretrain_folder = os.path.join(model_folder, "pretrain")
    os.makedirs(advtrain_folder, exist_ok=True)
    os.makedirs(pretrain_folder, exist_ok=True)

    if data_params.get("samp_split"):
        tqdm.write(f"Adversarial training for slides {mat_sp_d['train'].keys()}: ")
        save_folder = os.path.join(advtrain_folder, "samp_split")
    else:
        tqdm.write(f"Adversarial training for slides {next(iter(mat_sp_d.keys()))}: ")
        save_folder = os.path.join(advtrain_folder, "one_model")
    os.makedirs(save_folder, exist_ok=True)

    # NOTE(review): train_params["lr"] and model_params["celldart_kwargs"]
    # ["bn_momentum"] exist in the config but are not forwarded here; confirm
    # whether da_cellfraction.train accepts/should receive them.
    embs, embs_noda, clssmodel, clssmodel_noda = da_cellfraction.train(
        sc_mix_d["train"],
        lab_mix_d["train"],
        target_d["train"],
        alpha=train_params.get("alpha", 0.6),
        alpha_lr=train_params.get("alpha_lr", 5),
        emb_dim=model_params["celldart_kwargs"].get("emb_dim", 64),
        batch_size=train_params.get("batch_size", 512),
        n_iterations=train_params.get("n_iter", 3000),
        initial_train=train_params.get("pretraining", True),
        initial_train_epochs=train_params.get("initial_train_epochs", 10),
        batch_size_initial_train=max(train_params.get("batch_size", 512), 512),
        seed=model_seed,
    )

    # Save models: "_noda" variants (presumably no domain adaptation) go to the
    # pretrain folder, the adversarially trained ones to save_folder.
    advtrain_final = os.path.join(save_folder, "final_model")
    pretrain_final = os.path.join(pretrain_folder, "final_model")
    os.makedirs(advtrain_final, exist_ok=True)
    os.makedirs(pretrain_final, exist_ok=True)

    clssmodel_noda.save(os.path.join(pretrain_final, "model"))
    clssmodel.save(os.path.join(advtrain_final, "model"))

    embs_noda.save(os.path.join(pretrain_final, "embs"))
    embs.save(os.path.join(advtrain_final, "embs"))

    with open(os.path.join(model_folder, "config.yml"), "w", encoding="utf-8") as f:
        yaml.safe_dump(config, f)


# One training run per (pseudospot seed, model seed) pair.
for seed_pair in zip(PS_SEEDS, MODEL_SEEDS):
    train(*seed_pair)

PS seed: 3679, model seed: 2353
Adversarial training for slides dict_keys(['151507', '151508', '151509', '151510', '151669', '151670', '151671', '151672', '151674', '151676']): 
Train on 100000 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
initial_train_done


  updates = self.state_updates


0.5752687343978882
0.5752687343978882
Iteration 99, source loss =  4.386, discriminator acc = 0.288
Iteration 199, source loss =  2.327, discriminator acc = 0.135
Iteration 299, source loss =  1.637, discriminator acc = 0.574
Iteration 399, source loss =  1.921, discriminator acc = 0.682
Iteration 499, source loss =  1.555, discriminator acc = 0.774
Iteration 599, source loss =  1.310, discriminator acc = 0.836
Iteration 699, source loss =  1.287, discriminator acc = 0.974
Iteration 799, source loss =  1.142, discriminator acc = 0.966
Iteration 899, source loss =  1.268, discriminator acc = 0.475
Iteration 999, source loss =  0.970, discriminator acc = 0.790
Iteration 1099, source loss =  1.010, discriminator acc = 0.481
Iteration 1199, source loss =  0.866, discriminator acc = 0.865
Iteration 1299, source loss =  0.760, discriminator acc = 0.942
Iteration 1399, source loss =  0.879, discriminator acc = 0.984
Iteration 1499, source loss =  0.730, discriminator acc = 0.713
Iteration 159

  updates = self.state_updates


0.5236840969657898
0.5236840969657898
Iteration 99, source loss =  4.612, discriminator acc = 0.223
Iteration 199, source loss =  1.603, discriminator acc = 0.752
Iteration 299, source loss =  1.737, discriminator acc = 0.288
Iteration 399, source loss =  2.200, discriminator acc = 0.937
Iteration 499, source loss =  1.384, discriminator acc = 0.846
Iteration 599, source loss =  1.281, discriminator acc = 0.839
Iteration 699, source loss =  1.180, discriminator acc = 0.930
Iteration 799, source loss =  1.192, discriminator acc = 0.929
Iteration 899, source loss =  1.206, discriminator acc = 0.728
Iteration 999, source loss =  1.164, discriminator acc = 0.697
Iteration 1099, source loss =  0.904, discriminator acc = 0.961
Iteration 1199, source loss =  1.056, discriminator acc = 0.927
Iteration 1299, source loss =  0.985, discriminator acc = 0.684
Iteration 1399, source loss =  0.940, discriminator acc = 0.527
Iteration 1499, source loss =  0.843, discriminator acc = 0.640
Iteration 159

  updates = self.state_updates


0.49521382831573485
0.49521382831573485
Iteration 99, source loss =  4.447, discriminator acc = 0.723
Iteration 199, source loss =  1.471, discriminator acc = 0.974
Iteration 299, source loss =  2.215, discriminator acc = 0.287
Iteration 399, source loss =  1.800, discriminator acc = 0.874
Iteration 499, source loss =  1.146, discriminator acc = 0.995
Iteration 599, source loss =  1.265, discriminator acc = 0.855
Iteration 699, source loss =  1.251, discriminator acc = 0.737
Iteration 799, source loss =  1.136, discriminator acc = 0.869
Iteration 899, source loss =  1.065, discriminator acc = 0.479
Iteration 999, source loss =  0.962, discriminator acc = 0.348
Iteration 1099, source loss =  0.804, discriminator acc = 0.907
Iteration 1199, source loss =  0.881, discriminator acc = 0.646
Iteration 1299, source loss =  0.809, discriminator acc = 0.128
Iteration 1399, source loss =  0.686, discriminator acc = 0.867
Iteration 1499, source loss =  0.697, discriminator acc = 0.990
Iteration 1

  updates = self.state_updates


0.4657786366939545
0.4657786366939545
Iteration 99, source loss =  3.631, discriminator acc = 0.605
Iteration 199, source loss =  1.949, discriminator acc = 0.553
Iteration 299, source loss =  2.359, discriminator acc = 0.288
Iteration 399, source loss =  2.323, discriminator acc = 0.620
Iteration 499, source loss =  1.330, discriminator acc = 0.866
Iteration 599, source loss =  1.568, discriminator acc = 0.704
Iteration 699, source loss =  1.171, discriminator acc = 0.632
Iteration 799, source loss =  1.328, discriminator acc = 0.816
Iteration 899, source loss =  1.233, discriminator acc = 0.925
Iteration 999, source loss =  1.086, discriminator acc = 0.697
Iteration 1099, source loss =  0.970, discriminator acc = 0.820
Iteration 1199, source loss =  0.990, discriminator acc = 0.860
Iteration 1299, source loss =  0.865, discriminator acc = 0.503
Iteration 1399, source loss =  0.712, discriminator acc = 0.951
Iteration 1499, source loss =  0.740, discriminator acc = 0.837
Iteration 159

  updates = self.state_updates


0.5068058612442017
0.5068058612442017
Iteration 99, source loss =  2.788, discriminator acc = 0.292
Iteration 199, source loss =  1.985, discriminator acc = 0.446
Iteration 299, source loss =  1.651, discriminator acc = 0.738
Iteration 399, source loss =  1.280, discriminator acc = 0.713
Iteration 499, source loss =  1.533, discriminator acc = 0.988
Iteration 599, source loss =  1.738, discriminator acc = 0.597
Iteration 699, source loss =  1.414, discriminator acc = 0.983
Iteration 799, source loss =  1.082, discriminator acc = 0.993
Iteration 899, source loss =  1.010, discriminator acc = 0.582
Iteration 999, source loss =  0.892, discriminator acc = 0.977
Iteration 1099, source loss =  0.818, discriminator acc = 0.996
Iteration 1199, source loss =  0.826, discriminator acc = 0.074
Iteration 1299, source loss =  0.701, discriminator acc = 0.939
Iteration 1399, source loss =  0.710, discriminator acc = 0.977
Iteration 1499, source loss =  0.749, discriminator acc = 0.889
Iteration 159

## 2. Evaluation


In [5]:
import argparse
import datetime
import logging

from src.da_models.model_utils.utils import get_metric_ctp
from src.da_utils.evaluator import Evaluator


# Metric handed to the Evaluator; "cos" selects the cosine variant (the
# results tables report "Cosine Distance").
_CTP_METRIC_NAME = "cos"
metric_ctp = get_metric_ctp(_CTP_METRIC_NAME)

def _device_to_str(device):
    """Return a ``tf.device``-style string such as ``"CPU:0"`` for ``device``.

    ``device`` is anything with a ``name`` attribute like
    ``"/physical_device:GPU:0"`` (e.g. a ``tf.config.PhysicalDevice``). The
    ``/physical_device:`` prefix is removed; other names are returned unchanged.
    """
    prefix = "/physical_device:"
    name = device.name
    # str.lstrip would strip a *set of characters*, not a prefix, and could eat
    # leading device-type letters that happen to be in that set.
    if name.startswith(prefix):
        return name[len(prefix):]
    return name

def get_device(cuda_index=None):
    """Choose a TensorFlow device string for use with ``tf.device``.

    Args:
        cuda_index: ``None`` for the first GPU; a negative value (int or numeric
            string) to force the first CPU; otherwise a GPU ordinal.

    Returns:
        A device string such as ``"GPU:0"`` or ``"CPU:0"``.

    Warnings:
        Emits ``UserWarning`` when the requested GPU ordinal does not exist
        (falls back to the first GPU). It also warns when no GPU is available
        at all (falls back to the first CPU), including when
        ``cuda_index`` is ``None``.
    """
    if cuda_index is not None and int(cuda_index) < 0:
        return _device_to_str(tf.config.list_physical_devices("CPU")[0])

    gpus = tf.config.list_physical_devices("GPU")
    if gpus:
        if cuda_index is None:
            return _device_to_str(gpus[0])
        cuda_index = int(cuda_index)
        if cuda_index < len(gpus):
            return _device_to_str(gpus[cuda_index])
        warnings.warn("GPU ordinal not valid; using default", category=UserWarning, stacklevel=2)
        return _device_to_str(gpus[0])

    # No GPU visible: previously cuda_index=None raised IndexError here.
    warnings.warn("Using CPU", category=UserWarning, stacklevel=2)
    return _device_to_str(tf.config.list_physical_devices("CPU")[0])

def main(args):
    """Evaluate one trained model configuration described by ``args``.

    Builds an ``Evaluator`` from ``vars(args)`` and ``metric_ctp``. It then calls,
    in order: ``eval_spots``, ``evaluate_embeddings``, ``eval_sc`` and
    ``produce_results``.
    """
    evaluator = Evaluator(vars(args), metric_ctp)
    for step_name in ("eval_spots", "evaluate_embeddings", "eval_sc", "produce_results"):
        getattr(evaluator, step_name)()

def _build_eval_parser():
    """Return the argparse parser mirroring the standalone evaluation CLI.

    In this notebook it is fed an explicit argument list rather than sys.argv.
    Note that ``--seed_override`` and ``--ps_seed`` have no ``type`` and are
    therefore parsed as strings.
    """
    parser = argparse.ArgumentParser(description="Evaluates.")
    parser.add_argument("--pretraining", "-p", action="store_true", help="force pretraining")
    parser.add_argument("--modelname", "-n", type=str, default="ADDA", help="model name")
    parser.add_argument("--milisi", "-m", action="store_false", help="no milisi")
    parser.add_argument("--config_fname", "-f", type=str, help="Name of the config file to use")
    parser.add_argument("--configs_dir", "-cdir", type=str, default="configs", help="config dir")
    parser.add_argument(
        "--njobs", type=int, default=1, help="Number of jobs to use for parallel processing."
    )
    parser.add_argument("--cuda", "-c", default=None, help="GPU index to use")
    parser.add_argument("--tmpdir", "-d", default=None, help="optional temporary results directory")
    parser.add_argument("--test", "-t", action="store_true", help="test mode")
    parser.add_argument(
        "--early_stopping",
        "-e",
        action="store_true",
        help="evaluate early stopping. Default: False",
    )
    parser.add_argument(
        "--reverse_val",
        "-r",
        action="store_true",
        # Trailing space needed: adjacent literals are concatenated verbatim.
        help="use best model through reverse validation. Will use provided "
        "config file to search across models, then use the one loaded. Default: False",
    )
    parser.add_argument("--model_dir", default="model", help="model directory")
    parser.add_argument("--results_dir", default="results", help="results directory")
    parser.add_argument(
        "--seed_override",
        default=None,
        help="seed to use for torch and numpy; overrides that in config file",
    )
    parser.add_argument(
        "--ps_seed",
        default=-1,
        help="specific pseudospot seed to use; default of -1 corresponds to 623",
    )
    return parser


# Parser and logging configuration do not depend on the seeds: set them up once.
eval_parser = _build_eval_parser()
logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.WARNING,
    format="%(asctime)s:%(levelname)s:%(name)s:%(message)s",
)

for ps_seed, model_seed in zip(PS_SEEDS, MODEL_SEEDS):
    args = eval_parser.parse_args(
        [
            f"--modelname={MODEL_NAME}",
            f"--config_fname={CONFIG_FNAME}",
            f"--configs_dir={CONFIGS_DIR}",
            "--njobs=-1",
            "--test",
            f"--model_dir={MODEL_DIR}",
            "--results_dir=results_FINAL",
            f"--seed_override={model_seed}",
            f"--ps_seed={ps_seed}",
            "--cuda=-1",  # negative index -> evaluate on CPU
        ]
    )

    script_start_time = datetime.datetime.now(datetime.timezone.utc)
    with tf.device(get_device(args.cuda)):
        # Small matmul as a sanity check of device placement (printed tensor
        # shows the device it ran on).
        a = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
        b = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
        c = tf.matmul(a, b)

        print(c)

        main(args)
    print("Script run time:", datetime.datetime.now(datetime.timezone.utc) - script_start_time)

Tensor("MatMul:0", shape=(2, 2), dtype=float32, device=/device:CPU:0)
Evaluating CellDART_original on with -1 jobs
Using library config:
None
Loading config basic_config.yml ... 
data_params:
  all_genes: false
  data_dir: ../AGrEDA/data
  dset: dlpfc
  n_markers: 20
  n_mix: 8
  n_spots: 100000
  samp_split: true
  sc_id: GSE144136
  scaler_name: minmax
  st_id: spatialLIBD
  st_split: false
lib_params: {}
model_params:
  celldart_kwargs:
    bn_momentum: 0.01
    emb_dim: 64
  model_version: std
train_params:
  alpha: 0.6
  alpha_lr: 5
  batch_size: 512
  initial_train_epochs: 10
  lr: 0.001
  n_iter: 15000
  pretraining: true
  reverse_val: true

Saving results to results_FINAL/CellDART_original/dlpfc/GSE144136_spatialLIBD/20markers/8mix_100000spots/minmax/std/2353 ...
Loading Data


  self.results_folder = self.temp_folder_holder.set_output_folder(


Loading ST adata: 
Getting predictions: 


  updates=self.state_updates,


Plotting Samples
[Parallel(n_jobs=12)]: Using backend LokyBackend with 12 concurrent workers.
[Parallel(n_jobs=12)]: Done   1 tasks      | elapsed:   16.7s
[Parallel(n_jobs=12)]: Done   2 out of  12 | elapsed:   16.8s remaining:  1.4min
[Parallel(n_jobs=12)]: Done   3 out of  12 | elapsed:   17.2s remaining:   51.5s
[Parallel(n_jobs=12)]: Done   4 out of  12 | elapsed:   17.5s remaining:   34.9s
[Parallel(n_jobs=12)]: Done   5 out of  12 | elapsed:   17.5s remaining:   24.5s
[Parallel(n_jobs=12)]: Done   6 out of  12 | elapsed:   17.6s remaining:   17.6s
[Parallel(n_jobs=12)]: Done   7 out of  12 | elapsed:   18.0s remaining:   12.8s
[Parallel(n_jobs=12)]: Done   8 out of  12 | elapsed:   18.0s remaining:    9.0s
[Parallel(n_jobs=12)]: Done   9 out of  12 | elapsed:   18.0s remaining:    6.0s
[Parallel(n_jobs=12)]: Done  10 out of  12 | elapsed:   18.1s remaining:    3.6s
[Parallel(n_jobs=12)]: Done  12 out of  12 | elapsed:   18.3s remaining:    0.0s
[Parallel(n_jobs=12)]: Done  12 ou

  updates=self.state_updates,


 milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151508: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151509: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151510: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151669: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151670: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151671: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151672: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151674: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151676: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 15167

  updates=self.state_updates,
  updates=self.state_updates,


                                          Pseudospots (Cosine Distance)  \
                                                                  train   
                       SC Split Sample ID                                 
Before DA              train    151507                         0.200982   
                                151508                         0.200982   
                                151509                         0.200982   
                                151510                         0.200982   
                                151669                         0.200982   
                                151670                         0.200982   
                                151671                         0.200982   
                                151672                         0.200982   
                                151674                         0.200982   
                                151676                         0.200982   
                       va

  self.results_folder = self.temp_folder_holder.set_output_folder(


Loading ST adata: 
Getting predictions: 


  updates=self.state_updates,


Plotting Samples
[Parallel(n_jobs=12)]: Using backend LokyBackend with 12 concurrent workers.
[Parallel(n_jobs=12)]: Done   1 tasks      | elapsed:   16.5s
[Parallel(n_jobs=12)]: Done   2 out of  12 | elapsed:   16.7s remaining:  1.4min
[Parallel(n_jobs=12)]: Done   3 out of  12 | elapsed:   17.1s remaining:   51.2s
[Parallel(n_jobs=12)]: Done   4 out of  12 | elapsed:   17.1s remaining:   34.2s
[Parallel(n_jobs=12)]: Done   5 out of  12 | elapsed:   17.3s remaining:   24.3s
[Parallel(n_jobs=12)]: Done   6 out of  12 | elapsed:   17.4s remaining:   17.4s
[Parallel(n_jobs=12)]: Done   7 out of  12 | elapsed:   17.5s remaining:   12.5s
[Parallel(n_jobs=12)]: Done   8 out of  12 | elapsed:   17.6s remaining:    8.8s
[Parallel(n_jobs=12)]: Done   9 out of  12 | elapsed:   17.6s remaining:    5.9s
[Parallel(n_jobs=12)]: Done  10 out of  12 | elapsed:   17.8s remaining:    3.6s
[Parallel(n_jobs=12)]: Done  12 out of  12 | elapsed:   17.9s remaining:    0.0s
[Parallel(n_jobs=12)]: Done  12 ou

  updates=self.state_updates,


 milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151508: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151509: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151510: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151669: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151670: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151671: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151672: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151674: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151676: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 15167

  updates=self.state_updates,
  updates=self.state_updates,


                                          Pseudospots (Cosine Distance)  \
                                                                  train   
                       SC Split Sample ID                                 
Before DA              train    151507                         0.161803   
                                151508                         0.161803   
                                151509                         0.161803   
                                151510                         0.161803   
                                151669                         0.161803   
                                151670                         0.161803   
                                151671                         0.161803   
                                151672                         0.161803   
                                151674                         0.161803   
                                151676                         0.161803   
                       va

  self.results_folder = self.temp_folder_holder.set_output_folder(


Loading ST adata: 
Getting predictions: 


  updates=self.state_updates,


Plotting Samples
[Parallel(n_jobs=12)]: Using backend LokyBackend with 12 concurrent workers.
[Parallel(n_jobs=12)]: Done   1 tasks      | elapsed:   23.3s
[Parallel(n_jobs=12)]: Done   2 out of  12 | elapsed:   25.4s remaining:  2.1min
[Parallel(n_jobs=12)]: Done   3 out of  12 | elapsed:   25.4s remaining:  1.3min
[Parallel(n_jobs=12)]: Done   4 out of  12 | elapsed:   25.4s remaining:   50.7s
[Parallel(n_jobs=12)]: Done   5 out of  12 | elapsed:   25.4s remaining:   35.5s
[Parallel(n_jobs=12)]: Done   6 out of  12 | elapsed:   25.4s remaining:   25.4s
[Parallel(n_jobs=12)]: Done   7 out of  12 | elapsed:   25.4s remaining:   18.1s
[Parallel(n_jobs=12)]: Done   8 out of  12 | elapsed:   25.4s remaining:   12.7s
[Parallel(n_jobs=12)]: Done   9 out of  12 | elapsed:   25.4s remaining:    8.5s
[Parallel(n_jobs=12)]: Done  10 out of  12 | elapsed:   25.4s remaining:    5.1s
[Parallel(n_jobs=12)]: Done  12 out of  12 | elapsed:   25.8s remaining:    0.0s
[Parallel(n_jobs=12)]: Done  12 ou

  updates=self.state_updates,


 milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151508: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151509: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151510: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151669: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151670: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151671: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151672: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151674: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151676: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 15167

  updates=self.state_updates,
  updates=self.state_updates,


                                          Pseudospots (Cosine Distance)  \
                                                                  train   
                       SC Split Sample ID                                 
Before DA              train    151507                         0.161534   
                                151508                         0.161534   
                                151509                         0.161534   
                                151510                         0.161534   
                                151669                         0.161534   
                                151670                         0.161534   
                                151671                         0.161534   
                                151672                         0.161534   
                                151674                         0.161534   
                                151676                         0.161534   
                       va

  self.results_folder = self.temp_folder_holder.set_output_folder(


Loading ST adata: 
Getting predictions: 


  updates=self.state_updates,


Plotting Samples
[Parallel(n_jobs=12)]: Using backend LokyBackend with 12 concurrent workers.
[Parallel(n_jobs=12)]: Done   1 tasks      | elapsed:   26.5s
[Parallel(n_jobs=12)]: Done   2 out of  12 | elapsed:   26.5s remaining:  2.2min
[Parallel(n_jobs=12)]: Done   3 out of  12 | elapsed:   26.6s remaining:  1.3min
[Parallel(n_jobs=12)]: Done   4 out of  12 | elapsed:   26.8s remaining:   53.5s
[Parallel(n_jobs=12)]: Done   5 out of  12 | elapsed:   27.3s remaining:   38.2s
[Parallel(n_jobs=12)]: Done   6 out of  12 | elapsed:   27.3s remaining:   27.3s
[Parallel(n_jobs=12)]: Done   7 out of  12 | elapsed:   27.5s remaining:   19.6s
[Parallel(n_jobs=12)]: Done   8 out of  12 | elapsed:   27.6s remaining:   13.8s
[Parallel(n_jobs=12)]: Done   9 out of  12 | elapsed:   27.8s remaining:    9.3s
[Parallel(n_jobs=12)]: Done  10 out of  12 | elapsed:   27.8s remaining:    5.6s
[Parallel(n_jobs=12)]: Done  12 out of  12 | elapsed:   27.9s remaining:    0.0s
[Parallel(n_jobs=12)]: Done  12 ou

  updates=self.state_updates,


 milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151508: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151509: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151510: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151669: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151670: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151671: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151672: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151674: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151676: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 15167

  updates=self.state_updates,
  updates=self.state_updates,


                                          Pseudospots (Cosine Distance)  \
                                                                  train   
                       SC Split Sample ID                                 
Before DA              train    151507                         0.140565   
                                151508                         0.140565   
                                151509                         0.140565   
                                151510                         0.140565   
                                151669                         0.140565   
                                151670                         0.140565   
                                151671                         0.140565   
                                151672                         0.140565   
                                151674                         0.140565   
                                151676                         0.140565   
                       va

  self.results_folder = self.temp_folder_holder.set_output_folder(


Loading ST adata: 
Getting predictions: 


  updates=self.state_updates,


Plotting Samples
[Parallel(n_jobs=12)]: Using backend LokyBackend with 12 concurrent workers.
[Parallel(n_jobs=12)]: Done   1 tasks      | elapsed:   33.7s
[Parallel(n_jobs=12)]: Done   2 out of  12 | elapsed:   33.7s remaining:  2.8min
[Parallel(n_jobs=12)]: Done   3 out of  12 | elapsed:   34.2s remaining:  1.7min
[Parallel(n_jobs=12)]: Done   4 out of  12 | elapsed:   34.2s remaining:  1.1min
[Parallel(n_jobs=12)]: Done   5 out of  12 | elapsed:   34.2s remaining:   47.9s
[Parallel(n_jobs=12)]: Done   6 out of  12 | elapsed:   34.2s remaining:   34.2s
[Parallel(n_jobs=12)]: Done   7 out of  12 | elapsed:   34.3s remaining:   24.5s
[Parallel(n_jobs=12)]: Done   8 out of  12 | elapsed:   34.4s remaining:   17.2s
[Parallel(n_jobs=12)]: Done   9 out of  12 | elapsed:   34.4s remaining:   11.5s
[Parallel(n_jobs=12)]: Done  10 out of  12 | elapsed:   34.4s remaining:    6.9s
[Parallel(n_jobs=12)]: Done  12 out of  12 | elapsed:   34.5s remaining:    0.0s
[Parallel(n_jobs=12)]: Done  12 ou

  updates=self.state_updates,


 milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151508: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151509: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151510: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151669: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151670: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151671: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151672: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151674: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 151676: TRAIN | milisi rf50 | VAL | milisi rf50 | TEST | milisi rf50 | 
Calculating domain shift for 15167

  updates=self.state_updates,
  updates=self.state_updates,


                                          Pseudospots (Cosine Distance)  \
                                                                  train   
                       SC Split Sample ID                                 
Before DA              train    151507                         0.166780   
                                151508                         0.166780   
                                151509                         0.166780   
                                151510                         0.166780   
                                151669                         0.166780   
                                151670                         0.166780   
                                151671                         0.166780   
                                151672                         0.166780   
                                151674                         0.166780   
                                151676                         0.166780   
                       va

