In [2]:
import scanpy as sc


# Dataset label -> path of its processed .h5ad file.
files = dict(
    pbmc_4k="data/processed/pbmc_4k_clean_qc.h5ad",
    pbmc_8k="data/processed/pbmc_8k_clean_qc.h5ad",
    pbmc_3p="data/processed/pbmc_3p_clean_qc.h5ad",
    pbmc_5p="data/processed/pbmc_5p_clean_qc.h5ad",
    cd4t="data/processed/cd4t_qc.h5ad",
    cd8t="data/processed/cd8t_d1_qc.h5ad",
    nk="data/processed/nk_qc.h5ad",
    gdt="data/processed/gdt_3donors_qc.h5ad",
)

# Read each processed dataset into `adatas` (keyed by label) and report its
# cell/gene counts as soon as it is loaded.
adatas = {}
for f, path in files.items():
    adata = sc.read_h5ad(path)
    adatas[f] = adata
    print(f"{f}: {adata.n_obs} cells, {adata.n_vars} genes")

pbmc_4k: 1831 cells, 2000 genes
pbmc_8k: 3609 cells, 2000 genes
pbmc_3p: 1053 cells, 13945 genes
pbmc_5p: 815 cells, 13511 genes
cd4t: 11209 cells, 13580 genes
cd8t: 38364 cells, 20079 genes
nk: 8355 cells, 13757 genes
cd8t: 38364 cells, 20079 genes
nk: 8355 cells, 13757 genes
gdt: 8163 cells, 16295 genes
gdt: 8163 cells, 16295 genes


# Rebalance data
Start with a balanced dataset for the initial model:  
* GDT: 8,000 cells total
* Non-GDT: 8,000 cells total
    * CD4 T: 2,000 cells (hard negatives - closer to GDT)
    * CD8 T: 2,000 cells (hard negatives)
    * NK: 2,000 cells (hard negatives)
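    * One of the following PBMC pairs (the code below samples 4k/8k for the initial model and uses 3p/5p for the dataset that is aligned to it afterwards):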
        * PBMC 3p/5p: 1,000 cells each (easy negatives)
        * PBMC 4k/8k: 1,000 cells each (easy negatives)

In [3]:
import data_functions as dfuncs
from importlib import reload
reload(dfuncs)


# Per-dataset cap on the number of cells sampled for the initial (4k/8k) model.
max_cells = dict(gdt=8000, cd4t=2000, cd8t=2000, nk=2000, pbmc_4k=1000, pbmc_8k=1000)

# Downsample every dataset to its cap.
# NOTE(review): the second return value presumably holds the cells that were
# not sampled (and None where nothing is left) — confirm in data_functions.
sampled_adatas, unused_adatas = dfuncs.balance_datasets(adatas, max_cells)

# Caps applied when drawing the purified-population cells for the 3p/5p set.
val_max_cells = dict(cd4t=5000, cd8t=5000)

# balance_datasets is only given real datasets, so drop the None entries first.
filtered_unused_adatas = {
    label: leftover
    for label, leftover in unused_adatas.items()
    if leftover is not None
}
val_purified_adatas, _ = dfuncs.balance_datasets(filtered_unused_adatas, val_max_cells)

# Datasets integrated to build the initial model: everything comes from the
# first balancing pass.
pbmc_4k8k_adatas = {
    label: sampled_adatas[label]
    for label in ("pbmc_4k", "pbmc_8k", "cd4t", "cd8t", "nk", "gdt")
}

# Datasets aligned to that model afterwards: PBMC 3p/5p from the first pass,
# purified populations from the second pass over the leftover cells.
pbmc_3p5p_adatas = {
    **{label: sampled_adatas[label] for label in ("pbmc_3p", "pbmc_5p")},
    **{label: val_purified_adatas[label] for label in ("cd4t", "cd8t", "nk", "gdt")},
}

Balanced pbmc_4k: sampled 1000 cells from 1831
Balanced pbmc_8k: sampled 1000 cells from 3609
Balanced pbmc_3p: using all 1053 cells
Balanced pbmc_5p: using all 815 cells
Balanced cd4t: sampled 2000 cells from 11209
Balanced cd8t: sampled 2000 cells from 38364
Balanced nk: sampled 2000 cells from 8355
Balanced cd8t: sampled 2000 cells from 38364
Balanced nk: sampled 2000 cells from 8355
Balanced gdt: sampled 8000 cells from 8163
Balanced pbmc_4k: using all 831 cells
Balanced pbmc_8k: using all 2609 cells
Balanced cd4t: sampled 5000 cells from 9209
Balanced gdt: sampled 8000 cells from 8163
Balanced pbmc_4k: using all 831 cells
Balanced pbmc_8k: using all 2609 cells
Balanced cd4t: sampled 5000 cells from 9209
Balanced cd8t: sampled 5000 cells from 36364
Balanced nk: using all 6355 cells
Balanced gdt: using all 163 cells
Balanced cd8t: sampled 5000 cells from 36364
Balanced nk: using all 6355 cells
Balanced gdt: using all 163 cells


In [4]:
# Sanity-check how many cells each dataset contributes to the two groups.

# Group used to build the initial model.
print("PBMC 4k/8k integration datasets")
for name, adata in pbmc_4k8k_adatas.items():
    print(f"{name}: {adata.n_obs} cells")

# Group aligned to the model afterwards; an entry may be None, in which case
# that is reported instead of a count.
print("\nPBMC 3p/5p integration datasets")
for name, adata in pbmc_3p5p_adatas.items():
    print(f"{name}: None" if adata is None else f"{name}: {adata.n_obs} cells")

PBMC 4k/8k integration datasets
pbmc_4k: 1000 cells
pbmc_8k: 1000 cells
cd4t: 2000 cells
cd8t: 2000 cells
nk: 2000 cells
gdt: 8000 cells

PBMC 3p/5p integration datasets
pbmc_3p: 1053 cells
pbmc_5p: 815 cells
cd4t: 5000 cells
cd8t: 5000 cells
nk: 6355 cells
gdt: 163 cells


# Integrate, align, log-normalize

## Adding prefix to cell barcodes

In [5]:
def _prefix_barcodes(adatas_by_name):
    """Prefix every cell barcode with its dataset name, modifying each AnnData in place.

    Barcodes become ``"<name>_<barcode>"`` so they stay unique once the
    datasets are combined. A barcode that already starts with ``"<name>_"`` is
    left unchanged, which makes this safe to run more than once: without that
    check, re-executing the cell would stack prefixes
    (``"pbmc_4k_pbmc_4k_..."``).

    Parameters
    ----------
    adatas_by_name : dict
        Mapping of dataset name to an object exposing ``obs_names``.
    """
    for name, adata in adatas_by_name.items():
        prefix = f"{name}_"
        adata.obs_names = [
            bc if bc.startswith(prefix) else prefix + bc
            for bc in map(str, adata.obs_names)
        ]


# Add prefix to cell barcodes to ensure uniqueness across datasets
_prefix_barcodes(pbmc_4k8k_adatas)
_prefix_barcodes(pbmc_3p5p_adatas)

# Check barcodes after adding prefixes (first two per dataset)
print("\nBarcodes after adding prefixes:")
print("PBMC 4k/8k datasets:")
for name, adata in pbmc_4k8k_adatas.items():
    print(f"{name}: {adata.obs_names[:2]}")

print("PBMC 3p/5p datasets:")
for name, adata in pbmc_3p5p_adatas.items():
    print(f"{name}: {adata.obs_names[:2]}")


Barcodes after adding prefixes:
PBMC 4k/8k datasets:
pbmc_4k: Index(['pbmc_4k_AGATCTGTCTGCGTAA-1', 'pbmc_4k_GGGTCTGCAACACCCG-1'], dtype='object')
pbmc_8k: Index(['pbmc_8k_TCAGGTAAGCACACAG-1', 'pbmc_8k_GCGAGAACATGTAAGA-1'], dtype='object')
cd4t: Index(['cd4t_CATCAACTTGTTCT-1', 'cd4t_CATGTACTGGCGAA-1'], dtype='object')
cd8t: Index(['cd8t_TACTTGTTCAGTCAGT-8', 'cd8t_GTTACAGGTGCGGTAA-14'], dtype='object')
nk: Index(['nk_GATTCGGATGGTAC-1', 'nk_TTCCAAACTTCTTG-1'], dtype='object')
gdt: Index(['gdt_donor3_d2_CGTGTCTAGAACAATC', 'gdt_donor2_d2_CGGTTAACACCCTATC'], dtype='object')
PBMC 3p/5p datasets:
pbmc_3p: Index(['pbmc_3p_0', 'pbmc_3p_1'], dtype='object')
pbmc_5p: Index(['pbmc_5p_0', 'pbmc_5p_1'], dtype='object')
cd4t: Index(['cd4t_AACGTCGAATACCG-1', 'cd4t_ATATGAACACGTAC-1'], dtype='object')
cd8t: Index(['cd8t_AAGGCAGTCGATCCCT-4', 'cd8t_GATGAAACAATGAATG-30'], dtype='object')
nk: Index(['nk_AAACATACCCAATG-1', 'nk_AAACATACGACGTT-1'], dtype='object')
gdt: Index(['gdt_donor1_d1_ACACCCTCACCATGTA', 

In [7]:
# Bare expression: display the repr of the whole 4k/8k dict (one AnnData
# summary per dataset) as the cell output, to inspect obs/var/uns contents
# before integration.
pbmc_4k8k_adatas

{'pbmc_4k': AnnData object with n_obs × n_vars = 1000 × 2000
     obs: 'celltype', 'n_genes', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'leiden'
     var: 'gene_ids', 'n_cells', 'mt', 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'mean', 'std'
     uns: 'celltype_colors', 'hvg', 'leiden', 'leiden_colors', 'log1p', 'neighbors', 'pca', 'umap'
     obsm: 'X_pca', 'X_umap'
     varm: 'PCs'
     obsp: 'connectivities', 'distances',
 'pbmc_8k': AnnData object with n_obs × n_vars = 1000 × 2000
     obs: 'celltype', 'n_genes', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'leiden'
     var: 'gene_ids', 'n_cells', 'mt', 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'mean', 'std'
     uns: 'celltype_colors', 'hvg', 'leiden', 'leiden_colors', 'log1p', 'ne

In [7]:
# Re-import data_functions so that edits to the helper module take effect
# without restarting the kernel.
reload(dfuncs)
# Integrate balanced datasets using scVI
# The helper returns two values: the integrated AnnData and the model object,
# which a later cell saves to disk. latent_dim=30 is passed straight to the
# helper.
pbmc_4k8k_clean_gdt_ad, scvi_model = dfuncs.integrate_scvi(pbmc_4k8k_adatas, latent_dim=30)
# Bare expression: show the integrated AnnData summary as the cell output.
pbmc_4k8k_clean_gdt_ad

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
TPU available: False, using: 0 TPU cores
/Users/thuvu/anaconda3/envs/modeling/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/Users/thuvu/anaconda3/envs/modeling/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training:   0%|          | 0/400 [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=400` reached.


AnnData object with n_obs × n_vars = 16000 × 451
    obs: 'celltype', 'n_genes', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'dataset', '_scvi_batch', '_scvi_labels'
    uns: '_scvi_uuid', '_scvi_manager_uuid'
    obsm: 'X_scVI'
    layers: 'scVI_corrected', 'scVI_corrected_log'

In [8]:
# Save the integrated 4k/8k AnnData and the scVI model under data/ready/
# NOTE(review): the same .h5ad path is written again by the final
# dfuncs.save_h5ad cell, after the `is_gdt` column is added — this first write
# looks redundant; confirm before removing.
pbmc_4k8k_clean_gdt_ad.write_h5ad("data/ready/pbmc_4k8k_clean_gdt_scvi_integrated.h5ad")
# The model directory saved here is the path the alignment step loads from.
# overwrite=True is passed so a re-run can replace an existing saved model.
scvi_model.save("data/ready/scvi_model", overwrite=True)

## Align PBMC 3p/5p data to the pre-trained 4k/8k model

In [9]:
# Re-import data_functions so that edits to the helper module take effect
# without restarting the kernel.
reload(dfuncs)
# Align the validation datasets using the trained model
# Arguments: the dict of 3p/5p-group AnnData objects, the directory the scVI
# model was saved to, and the integrated 4k/8k AnnData as the reference.
# NOTE(review): the reference is presumably used to match the query genes to
# the model's gene set — confirm in data_functions.
pbmc_3p5p_clean_gdt_ad = dfuncs.align_to_scvi_model(
    pbmc_3p5p_adatas, 
    "data/ready/scvi_model",
    reference_adata=pbmc_4k8k_clean_gdt_ad,
)
# Bare expression: show the aligned AnnData summary as the cell output.
pbmc_3p5p_clean_gdt_ad

[34mINFO    [0m File data/ready/scvi_model/model.pt already downloaded                                                    


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
TPU available: False, using: 0 TPU cores


Common genes: 409/451
Missing in query: 42
Training query model...


/Users/thuvu/anaconda3/envs/modeling/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training:   0%|          | 0/200 [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=200` reached.


AnnData object with n_obs × n_vars = 18386 × 451
    obs: 'celltype', 'dataset', '_scvi_batch', '_scvi_labels'
    uns: '_scvi_uuid', '_scvi_manager_uuid'
    obsm: 'X_scVI'
    layers: 'scVI_corrected', 'scVI_corrected_log'

In [10]:
# Add the binary target column `is_gdt` to both integrated objects:
# 1 where `celltype` equals "GDT", 0 otherwise.
for integrated_ad in (pbmc_3p5p_clean_gdt_ad, pbmc_4k8k_clean_gdt_ad):
    integrated_ad.obs["is_gdt"] = (integrated_ad.obs["celltype"] == "GDT").astype(int)

In [13]:
# Save both integrated AnnData objects (4k/8k first, then the 3p/5p validation
# set) into data/ready/, now that they carry the `is_gdt` column.
for integrated, filename in (
    (pbmc_4k8k_clean_gdt_ad, "pbmc_4k8k_clean_gdt_scvi_integrated.h5ad"),
    (pbmc_3p5p_clean_gdt_ad, "pbmc_3p5p_clean_gdt_scvi_integrated.h5ad"),
):
    dfuncs.save_h5ad(integrated, filename, "data/ready/")

Saved data/ready/pbmc_4k8k_clean_gdt_scvi_integrated.h5ad
Saved data/ready/pbmc_3p5p_clean_gdt_scvi_integrated.h5ad
Saved data/ready/pbmc_3p5p_clean_gdt_scvi_integrated.h5ad
