In [4]:
import os
import scanpy as sc
import scvi
import json
from sklearn.model_selection import train_test_split
import numpy as np
import json

## Data download and train/valid/test split

In [9]:

# All inputs and outputs of this notebook live under one local directory.
data_dir = "data_input"
os.makedirs(data_dir, exist_ok=True)

# Raw download plus the three split files written further down this cell.
pancreas_adata_path, train_path, valid_path, test_path = (
    os.path.join(data_dir, fname)
    for fname in (
        "pancreas_full.h5ad",
        "pancreas_train.h5ad",
        "pancreas_valid.h5ad",
        "pancreas_test.h5ad",
    )
)

# Pancreas dataset hosted on figshare (~301 MB download).
PANCREAS_URL = "https://figshare.com/ndownloader/files/24539828"

# sc.read only fetches backup_url when pancreas_adata_path does not exist on
# disk. The cleanup step at the end of this cell deletes that file, so every
# re-run of this cell downloads the dataset again.
pancreas_adata = sc.read(pancreas_adata_path, backup_url=PANCREAS_URL)

# Split parameters. The listed technologies are held out entirely as the test
# set, so train/valid never contain cells from them.
TEST_TECHS = ["smartseq2", "celseq2"]
VALID_FRACTION = 0.20
SPLIT_SEED = 42

query_mask = pancreas_adata.obs["tech"].isin(TEST_TECHS).to_numpy()
pancreas_no_test = pancreas_adata[~query_mask].copy()
pancreas_test    = pancreas_adata[ query_mask].copy()

# 80/20 train/valid split on the remaining cells, stratified by technology so
# each remaining technology keeps the same proportion in both splits.
# Only test_size is passed: sklearn then uses the complement as the train
# size, which is exactly what the former train_size=0.80 expressed.
y = pancreas_no_test.obs["tech"].astype("category")
indices = np.arange(pancreas_no_test.n_obs)

idx_train, idx_valid = train_test_split(
    indices,
    test_size=VALID_FRACTION,
    random_state=SPLIT_SEED,
    shuffle=True,
    stratify=y,  # stratify by technology
)

pancreas_train = pancreas_no_test[idx_train].copy()
pancreas_valid = pancreas_no_test[idx_valid].copy()

# Invariants: the three splits together cover every cell exactly once, and
# no held-out technology leaks into train/valid.
assert (
    pancreas_train.n_obs + pancreas_valid.n_obs + pancreas_test.n_obs
    == pancreas_adata.n_obs
), "train/valid/test splits do not cover the full dataset"
assert not pancreas_train.obs["tech"].isin(TEST_TECHS).any(), "test tech leaked into train"
assert not pancreas_valid.obs["tech"].isin(TEST_TECHS).any(), "test tech leaked into valid"

# Persist each split to its own .h5ad file.
for split_adata, split_path in (
    (pancreas_train, train_path),
    (pancreas_valid, valid_path),
    (pancreas_test, test_path),
):
    split_adata.write(split_path)

# One-line size summary of the three splits.
print(
    f"Train: {pancreas_train.n_obs} cells | "
    f"Valid: {pancreas_valid.n_obs} cells | "
    f"Test: {pancreas_test.n_obs} cells"
)

# Per-split breakdown of cell counts by sequencing technology, sorted by name.
print("\nCells per technology:")
split_views = {
    "Train": pancreas_train,
    "Valid": pancreas_valid,
    "Test": pancreas_test,
}
for split_name, split_adata in split_views.items():
    per_tech = split_adata.obs["tech"].value_counts().sort_index()
    print(f"\n{split_name} split:")
    for tech_label, n_cells in per_tech.items():
        print(f"  {tech_label}: {n_cells}")

# --- Cleanup: delete the original full dataset file ---
# Drop the in-memory full AnnData; only the split objects are used from here
# on. (sc.read above is called without `backed=`, so presumably no file
# handle is held open — this mainly frees memory.)
del pancreas_adata

# Best-effort removal of the downloaded file. Only filesystem errors are
# downgraded to a warning; any other exception is a real bug and propagates.
# Note: because the file is removed, re-running this cell re-downloads it.
try:
    os.remove(pancreas_adata_path)
except FileNotFoundError:
    pass  # already absent: nothing to clean up (same as the old exists() check)
except OSError as e:
    print(f"[WARN] Could not delete '{pancreas_adata_path}': {e}")
else:
    print(f"Deleted '{pancreas_adata_path}'")



100%|██████████| 301M/301M [00:15<00:00, 20.9MB/s] 


Train: 9362 cells | Valid: 2341 cells | Test: 4679 cells

Cells per technology:

Train split:
  celseq: 803
  fluidigmc1: 510
  inDrop1: 1550
  inDrop2: 1379
  inDrop3: 2884
  inDrop4: 1042
  smarter: 1194

Valid split:
  celseq: 201
  fluidigmc1: 128
  inDrop1: 387
  inDrop2: 345
  inDrop3: 721
  inDrop4: 261
  smarter: 298

Test split:
  celseq2: 2285
  smartseq2: 2394
Deleted 'data_input/pancreas_full.h5ad'
Saved 19093 genes to data_input/all_genes_list.json


In [6]:
# Utility to load HVG list
def load_hvg_list(hvg_list_path):
    """Load a list of highly variable gene (HVG) names from a JSON file.

    Args:
        hvg_list_path: Path to a JSON file whose top-level value is a list of
            gene-name strings.

    Returns:
        The list of gene names, in file order.

    Raises:
        TypeError: If the file's top-level JSON value is not a list of strings.
    """
    # JSON text is UTF-8; don't depend on the platform's default encoding.
    with open(hvg_list_path, encoding="utf-8") as f:
        genes = json.load(f)
    # Callers index AnnData columns with this value, so reject anything that
    # is not a plain list of names with a clear message.
    if not isinstance(genes, list) or not all(isinstance(g, str) for g in genes):
        raise TypeError(
            f"{hvg_list_path} must contain a JSON list of gene-name strings"
        )
    return genes

# NOTE(review): hvg_list.json is not created by any cell visible in this
# notebook — it must exist in data_dir before this cell runs.
hvg_list = load_hvg_list(os.path.join(data_dir, "hvg_list.json"))

# Restrict to HVG genes. Check membership first so a stale/mismatched HVG
# file fails with an explicit message instead of anndata's generic KeyError.
missing_hvgs = sorted(set(hvg_list) - set(pancreas_train.var_names))
if missing_hvgs:
    raise KeyError(
        f"{len(missing_hvgs)} HVG genes are not in pancreas_train.var_names, "
        f"e.g. {missing_hvgs[:5]}"
    )
# NOTE(review): only the training split is restricted here; pancreas_valid and
# pancreas_test keep all genes — confirm that is intended for later use.
pancreas_train = pancreas_train[:, hvg_list].copy()

## Train the scVI model


In [7]:
# Seed scvi-tools' global RNG state so re-running this cell trains the same
# model (the train/valid split above is likewise seeded with 42). Exact
# bit-for-bit reproducibility can still depend on hardware/backend.
scvi.settings.seed = 42

# Register the training data: raw counts are read from the "counts" layer and
# "tech" is the batch covariate the model corrects for.
scvi.model.SCVI.setup_anndata(pancreas_train, batch_key="tech", layer="counts")

# Layer norm instead of batch norm, with covariates fed to the encoder.
# NOTE(review): these look like the settings scvi-tools suggests for later
# scArches-style query mapping — confirm that is the intent.
scvi_ref = scvi.model.SCVI(
    pancreas_train,
    use_layer_norm="both",
    use_batch_norm="none",
    encode_covariates=True,
    dropout_rate=0.2,
    n_layers=2,
)
# NOTE(review): train() holds out part of pancreas_train for its own
# validation by default; pancreas_valid built above is not used here.
scvi_ref.train(max_epochs=50)

  self.validate_field(adata)
  accelerator, lightning_devices, device = parse_device_args(
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/dmalpetti/miniconda3/envs/fl-course-env/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/Users/dmalpetti/miniconda3/envs/fl-course-env/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Epoch 1/50:   0%|          | 0/50 [00:00<?, ?it/s]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 2/50:   2%|▏         | 1/50 [00:01<00:50,  1.03s/it, v_num=1, train_loss_step=1.09e+3, train_loss_epoch=1.41e+3]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 3/50:   4%|▍         | 2/50 [00:02<00:49,  1.02s/it, v_num=1, train_loss_step=1.02e+3, train_loss_epoch=1.06e+3]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 4/50:   6%|▌         | 3/50 [00:03<00:57,  1.22s/it, v_num=1, train_loss_step=1.01e+3, train_loss_epoch=991]    

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 5/50:   8%|▊         | 4/50 [00:04<00:55,  1.20s/it, v_num=1, train_loss_step=1.03e+3, train_loss_epoch=942]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 6/50:  10%|█         | 5/50 [00:05<00:53,  1.18s/it, v_num=1, train_loss_step=855, train_loss_epoch=916]    

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 7/50:  12%|█▏        | 6/50 [00:06<00:50,  1.14s/it, v_num=1, train_loss_step=922, train_loss_epoch=901]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 8/50:  14%|█▍        | 7/50 [00:07<00:47,  1.11s/it, v_num=1, train_loss_step=925, train_loss_epoch=890]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 9/50:  16%|█▌        | 8/50 [00:08<00:45,  1.07s/it, v_num=1, train_loss_step=825, train_loss_epoch=882]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 10/50:  18%|█▊        | 9/50 [00:10<00:44,  1.08s/it, v_num=1, train_loss_step=863, train_loss_epoch=874]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 11/50:  20%|██        | 10/50 [00:11<00:44,  1.11s/it, v_num=1, train_loss_step=811, train_loss_epoch=868]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 12/50:  22%|██▏       | 11/50 [00:12<00:43,  1.11s/it, v_num=1, train_loss_step=747, train_loss_epoch=863]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 13/50:  24%|██▍       | 12/50 [00:13<00:42,  1.11s/it, v_num=1, train_loss_step=902, train_loss_epoch=858]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 14/50:  26%|██▌       | 13/50 [00:14<00:39,  1.06s/it, v_num=1, train_loss_step=803, train_loss_epoch=854]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 15/50:  28%|██▊       | 14/50 [00:15<00:38,  1.08s/it, v_num=1, train_loss_step=854, train_loss_epoch=851]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 16/50:  30%|███       | 15/50 [00:16<00:41,  1.18s/it, v_num=1, train_loss_step=933, train_loss_epoch=847]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 17/50:  32%|███▏      | 16/50 [00:18<00:40,  1.19s/it, v_num=1, train_loss_step=862, train_loss_epoch=845]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 18/50:  34%|███▍      | 17/50 [00:19<00:38,  1.18s/it, v_num=1, train_loss_step=748, train_loss_epoch=842]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 19/50:  36%|███▌      | 18/50 [00:20<00:36,  1.15s/it, v_num=1, train_loss_step=711, train_loss_epoch=840]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 20/50:  38%|███▊      | 19/50 [00:21<00:34,  1.11s/it, v_num=1, train_loss_step=767, train_loss_epoch=837]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 21/50:  40%|████      | 20/50 [00:22<00:33,  1.13s/it, v_num=1, train_loss_step=803, train_loss_epoch=835]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 22/50:  42%|████▏     | 21/50 [00:23<00:33,  1.17s/it, v_num=1, train_loss_step=884, train_loss_epoch=833]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 23/50:  44%|████▍     | 22/50 [00:24<00:31,  1.14s/it, v_num=1, train_loss_step=894, train_loss_epoch=831]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 24/50:  46%|████▌     | 23/50 [00:25<00:29,  1.09s/it, v_num=1, train_loss_step=787, train_loss_epoch=829]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 25/50:  48%|████▊     | 24/50 [00:26<00:27,  1.06s/it, v_num=1, train_loss_step=737, train_loss_epoch=827]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 26/50:  50%|█████     | 25/50 [00:27<00:26,  1.05s/it, v_num=1, train_loss_step=863, train_loss_epoch=826]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 27/50:  52%|█████▏    | 26/50 [00:29<00:26,  1.10s/it, v_num=1, train_loss_step=854, train_loss_epoch=824]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 28/50:  54%|█████▍    | 27/50 [00:30<00:24,  1.07s/it, v_num=1, train_loss_step=845, train_loss_epoch=823]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 29/50:  56%|█████▌    | 28/50 [00:31<00:22,  1.03s/it, v_num=1, train_loss_step=869, train_loss_epoch=821]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 30/50:  58%|█████▊    | 29/50 [00:32<00:21,  1.01s/it, v_num=1, train_loss_step=796, train_loss_epoch=820]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 31/50:  60%|██████    | 30/50 [00:33<00:20,  1.03s/it, v_num=1, train_loss_step=871, train_loss_epoch=819]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 32/50:  62%|██████▏   | 31/50 [00:34<00:19,  1.00s/it, v_num=1, train_loss_step=786, train_loss_epoch=817]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 33/50:  64%|██████▍   | 32/50 [00:35<00:17,  1.00it/s, v_num=1, train_loss_step=799, train_loss_epoch=816]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 34/50:  66%|██████▌   | 33/50 [00:36<00:17,  1.02s/it, v_num=1, train_loss_step=791, train_loss_epoch=816]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 35/50:  68%|██████▊   | 34/50 [00:37<00:16,  1.03s/it, v_num=1, train_loss_step=907, train_loss_epoch=814]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 36/50:  70%|███████   | 35/50 [00:38<00:15,  1.00s/it, v_num=1, train_loss_step=850, train_loss_epoch=813]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 37/50:  72%|███████▏  | 36/50 [00:39<00:13,  1.01it/s, v_num=1, train_loss_step=868, train_loss_epoch=812]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 38/50:  74%|███████▍  | 37/50 [00:40<00:12,  1.01it/s, v_num=1, train_loss_step=817, train_loss_epoch=811]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 39/50:  76%|███████▌  | 38/50 [00:41<00:12,  1.00s/it, v_num=1, train_loss_step=780, train_loss_epoch=810]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 40/50:  78%|███████▊  | 39/50 [00:42<00:10,  1.00it/s, v_num=1, train_loss_step=752, train_loss_epoch=809]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 41/50:  80%|████████  | 40/50 [00:43<00:09,  1.00it/s, v_num=1, train_loss_step=815, train_loss_epoch=808]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 42/50:  82%|████████▏ | 41/50 [00:44<00:09,  1.00s/it, v_num=1, train_loss_step=817, train_loss_epoch=807]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 43/50:  84%|████████▍ | 42/50 [00:45<00:08,  1.02s/it, v_num=1, train_loss_step=787, train_loss_epoch=807]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 44/50:  86%|████████▌ | 43/50 [00:46<00:07,  1.01s/it, v_num=1, train_loss_step=818, train_loss_epoch=806]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 45/50:  88%|████████▊ | 44/50 [00:47<00:06,  1.04s/it, v_num=1, train_loss_step=763, train_loss_epoch=806]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 46/50:  90%|█████████ | 45/50 [00:48<00:05,  1.03s/it, v_num=1, train_loss_step=804, train_loss_epoch=805]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 47/50:  92%|█████████▏| 46/50 [00:49<00:04,  1.05s/it, v_num=1, train_loss_step=961, train_loss_epoch=804]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 48/50:  94%|█████████▍| 47/50 [00:50<00:03,  1.04s/it, v_num=1, train_loss_step=807, train_loss_epoch=803]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 49/50:  96%|█████████▌| 48/50 [00:51<00:02,  1.06s/it, v_num=1, train_loss_step=786, train_loss_epoch=803]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 50/50:  98%|█████████▊| 49/50 [00:52<00:01,  1.07s/it, v_num=1, train_loss_step=702, train_loss_epoch=802]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 50/50: 100%|██████████| 50/50 [00:53<00:00,  1.10s/it, v_num=1, train_loss_step=656, train_loss_epoch=802]

`Trainer.fit` stopped: `max_epochs=50` reached.


Epoch 50/50: 100%|██████████| 50/50 [00:53<00:00,  1.07s/it, v_num=1, train_loss_step=656, train_loss_epoch=802]


In [8]:
# Persist the trained reference model; overwrite=True replaces any earlier
# save in the same directory. The AnnData itself is not written alongside.
MODEL_DIR = "model_centralized"
scvi_ref.save(MODEL_DIR, overwrite=True)