In [1]:
import anndata

In [2]:
adata = anndata.read_h5ad('/lustre/scratch126/cellgen/team292/ha10/data/Heart_Atlas/adata_Heart_Reichart_HV_train.h5ad')

In [3]:
adata

AnnData object with n_obs × n_vars = 284727 × 5000
    obs: 'Sample', 'donor_id', 'Region_x', 'Primary.Genetic.Diagnosis', 'n_genes', 'n_counts', 'percent_mito', 'percent_ribo', 'scrublet_score_z', 'scrublet_score_log', 'solo_score', 'cell_states', 'Assigned', 'self_reported_ethnicity_ontology_term_id', 'disease_ontology_term_id', 'cell_type_ontology_term_id', 'sex_ontology_term_id', 'assay_ontology_term_id', 'organism_ontology_term_id', 'is_primary_data', 'tissue_ontology_term_id', 'development_stage_ontology_term_id', 'suspension_type', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'disease_renamed'
    var: 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'ENS'
    uns: 'Primary.Genetic.Diagnosis_colors', 'Region_x_colors', 'cell_states_colors', 'cell_type_colors', 'cell_type_ontology_term_id_colors', 'diseas

In [4]:
from scDisInFact import scdisinfact, create_scdisinfact_dataset
data_dict = create_scdisinfact_dataset(adata.layers['counts'], adata.obs, condition_key = ["disease"], batch_key = "donor_id")

Sanity check...
Finished.
Create scDisInFact datasets...
Finished.


In [13]:
len(data_dict['datasets'])

22

In [14]:
import torch

In [15]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [7]:
model = scdisinfact(data_dict = data_dict, device = device)

In [8]:
losses = model.train_model()

Epoch 0, Validating Loss: 1.0202
	 loss reconstruction: 0.19029
	 loss kl comm: 27.58274
	 loss kl diff: 5.89108
	 loss mmd common: 16.92033
	 loss mmd diff: 12.89788
	 loss classification: 0.67658
	 loss group lasso diff: 0.09114
GPU memory usage: 301.405762MB
Epoch 10, Validating Loss: 0.5895
	 loss reconstruction: 0.15013
	 loss kl comm: 40.58490
	 loss kl diff: 6.23135
	 loss mmd common: 15.13281
	 loss mmd diff: 14.46283
	 loss classification: 0.27565
	 loss group lasso diff: 0.09801
GPU memory usage: 301.443848MB
Epoch 20, Validating Loss: 0.5306
	 loss reconstruction: 0.14690
	 loss kl comm: 28.58654
	 loss kl diff: 5.35616
	 loss mmd common: 10.30627
	 loss mmd diff: 13.66662
	 loss classification: 0.21939
	 loss group lasso diff: 0.10809
GPU memory usage: 301.443848MB
Epoch 30, Validating Loss: 0.4896
	 loss reconstruction: 0.14546
	 loss kl comm: 25.33405
	 loss kl diff: 5.00176
	 loss mmd common: 8.59200
	 loss mmd diff: 11.54021
	 loss classification: 0.17344
	 loss group l

In [9]:
torch.save(model.state_dict(), f"model_scdisinfact.pth")

In [10]:
_ = model.eval()

In [11]:
import numpy as np

In [17]:
model = torch.load('model_scdisinfact.pth', map_location=torch.device('cpu'))

  model = torch.load('model_scdisinfact.pth', map_location=torch.device('cpu'))


In [None]:
# one forward pass
z_cs = []
z_ds = []
zs = []

for dataset in data_dict["datasets"]:
    with torch.no_grad():
        # pass through the encoders
        dict_inf = model.inference(counts = dataset.counts_norm.to(model.device), batch_ids = dataset.batch_id[:,None].to(model.device), print_stat = True)
        # pass through the decoder
        dict_gen = model.generative(z_c = dict_inf["mu_c"], z_d = dict_inf["mu_d"], batch_ids = dataset.batch_id[:,None].to(model.device))
        z_c = dict_inf["mu_c"]
        z_d = dict_inf["mu_d"]
        z = torch.cat([z_c] + z_d, dim = 1)
        mu = dict_gen["mu"]    
        z_ds.append([x.cpu().detach().numpy() for x in z_d])
        z_cs.append(z_c.cpu().detach().numpy())
        zs.append(np.concatenate([z_cs[-1]] + z_ds[-1], axis = 1))

In [13]:
len(zs)

22

In [14]:
zs0 = np.concatenate(zs, axis = 0)


In [15]:
zs0.shape

(284727, 12)

In [17]:
zs0

array([[-0.74021053,  0.9917302 ,  0.9823548 , ..., -0.1353738 ,
         1.5467132 ,  1.1094648 ],
       [-0.39822677, -0.6310626 , -0.4166441 , ...,  1.4063932 ,
         5.859304  ,  2.2182906 ],
       [-1.5629592 ,  0.45029968, -1.3686004 , ...,  0.14934751,
         4.8232665 ,  3.183637  ],
       ...,
       [-1.7748965 , -1.1773459 ,  0.2297636 , ...,  0.66093093,
         4.786968  ,  2.7253006 ],
       [-0.39350608, -0.83237445, -0.34193796, ...,  0.44534716,
         5.2973104 ,  2.7139065 ],
       [-1.0069495 ,  1.1501207 ,  0.7784806 , ...,  0.88178456,
         4.6801953 ,  2.815968  ]], dtype=float32)

In [18]:
zs0.tofile('scdisinfact_embed.csv', sep = ',')