In [2]:
import os
# Expose only GPU 0 to this process. This is set before `import torch` below so
# that CUDA only ever sees the restricted device list.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import anndata as ad
import pandas as pd
import scanpy as sc
import numpy as np
import torch
# NOTE(review): `ad`, `pd`, `evaluate` and `prepare_for_transfer` are not used in
# the cells shown here — confirm nothing else needs them before removing.
from scETM import scETM, UnsupervisedTrainer, evaluate, prepare_for_transfer
# Notebook-wide scanpy figure defaults (resolution, font size, figure size, background).
sc.set_figure_params(dpi=120, dpi_save=250, fontsize=10, figsize=(10, 10), facecolor="white")

In [3]:
# Load the training AnnData (expression matrix plus per-cell metadata in .obs).
adata = sc.read_h5ad('data/sc_training.h5ad')

# Batch label per cell = last character of its sequencing-lane label.
# NOTE(review): "batch_indices" is presumably the batch column scETM reads by
# default — confirm against the scETM version in use.
lanes = adata.obs["lane"].astype(str)
adata.obs["batch_indices"] = np.array([lane[-1] for lane in lanes])

# Guard: two different lanes must not collapse onto the same one-character label
# (e.g. "lane1" and "lane11"), which would silently merge distinct batches.
lane_to_batch = dict(zip(lanes, adata.obs["batch_indices"]))
assert len(set(lane_to_batch.values())) == len(lane_to_batch), (
    "distinct lane labels share the same last character; batch labels would collide"
)

# Width of the trainable gene embedding (rho) learned by scETM.
emb_dim = 32

Unsupervised training of scETM to reconstruct the scRNA-seq expression data

In [1]:
# Build a fresh scETM model. n_batches is derived from the batch labels so it
# always matches obs["batch_indices"] instead of being hard-coded.
model = scETM(adata.n_vars,
              n_batches=adata.obs["batch_indices"].nunique(),
              trainable_gene_emb_dim=emb_dim,
             )

# test_ratio is presumably the fraction of cells held out for the "test nll"
# logged during training. Checkpoints are written under ckpt_dir into a
# per-run, timestamped subdirectory (e.g. scETM_01_14-12_57_32/).
trainer = UnsupervisedTrainer(model, adata,
                              test_ratio=0.01,
                              ckpt_dir="submission/checkpoints/",
                              init_lr=1e-4,
                              batch_size=16000,
                              seed=24,
                             )

# 12000 epochs, evaluating and checkpointing every 6000 epochs (so checkpoints
# at 6000 and 12000). KL-weight annealing (min_kl_weight / max_kl_weight) is
# left at the library defaults. The recorded run took roughly 2.5 h on one GPU.
trainer.train(n_epochs=12000,
              eval_every=6000,
              eval_kwargs=dict(cell_type_col='state'),
              save_model_ckpt=True)


[2023-01-14 12:57:30,829] INFO - scETM.logging_utils: scETM.__init__(15077, n_batches = 4, trainable_gene_emb_dim = 32)
[2023-01-14 12:57:32,221] INFO - scETM.logging_utils: UnsupervisedTrainer.__init__(scETM(
  (q_delta): Sequential(
    (0): Linear(in_features=15077, out_features=128, bias=True)
    (1): ReLU()
    (2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Dropout(p=0.1, inplace=False)
  )
  (mu_q_delta): Linear(in_features=128, out_features=50, bias=True)
  (logsigma_q_delta): Linear(in_features=128, out_features=50, bias=True)
  (rho_trainable_emb): PartlyTrainableParameter2D(height=32, fixed=0, trainable=15077)
), AnnData object with n_obs × n_vars = 28697 × 15077
    obs: 'gRNA_maxID', 'state', 'condition', 'lane', 'batch_indices'
    layers: 'rawcounts', test_ratio = 0.01, ckpt_dir = submission/checkpoints/, init_lr = 0.0001, batch_size = 16000, seed = 24)
[2023-01-14 12:57:32,222] INFO - scETM.trainers.trainer_utils: Set seed

loss:      10.83	nll:      10.83	kl_delta:     0.2007	max_norm:     0.2975	Epoch     0/12000	Next ckpt:       0

[2023-01-14 12:57:36,592] INFO - scETM.trainers.UnsupervisedTrainer: pmem(rss=15544811520, vms=34319380480, shared=723554304, text=2043904, lib=0, data=18835697664, dirty=0)
[2023-01-14 12:57:36,593] INFO - scETM.trainers.UnsupervisedTrainer: lr          :    9.999e-05
[2023-01-14 12:57:36,594] INFO - scETM.trainers.UnsupervisedTrainer: kl_weight   :            0
[2023-01-14 12:57:36,595] INFO - scETM.trainers.trainer_utils: loss        :      10.83
[2023-01-14 12:57:36,596] INFO - scETM.trainers.trainer_utils: nll         :      10.83
[2023-01-14 12:57:36,599] INFO - scETM.trainers.trainer_utils: kl_delta    :     0.2007
[2023-01-14 12:57:36,599] INFO - scETM.trainers.trainer_utils: max_norm    :     0.2975
[2023-01-14 12:57:36,634] INFO - scETM.trainers.UnsupervisedTrainer: test nll: 10.3841
[2023-01-14 12:57:38,378] INFO - scETM.logging_utils: evaluate(adata = AnnData object with n_obs × n_vars = 28697 × 15077
    obs: 'gRNA_maxID', 'state', 'condition', 'lane', 'batch_indices'
    

loss:      8.381	nll:      8.381	kl_delta:      268.7	max_norm:     0.2632	Epoch  5999/12000	Next ckpt:    6000

[2023-01-14 14:04:06,616] INFO - scETM.trainers.UnsupervisedTrainer: pmem(rss=18963922944, vms=47088144384, shared=1049296896, text=2043904, lib=0, data=22971084800, dirty=0)
INFO:scETM.trainers.UnsupervisedTrainer:pmem(rss=18963922944, vms=47088144384, shared=1049296896, text=2043904, lib=0, data=22971084800, dirty=0)
[2023-01-14 14:04:06,618] INFO - scETM.trainers.UnsupervisedTrainer: lr          :    5.277e-05
INFO:scETM.trainers.UnsupervisedTrainer:lr          :    5.277e-05
[2023-01-14 14:04:06,620] INFO - scETM.trainers.UnsupervisedTrainer: kl_weight   :        1e-07
INFO:scETM.trainers.UnsupervisedTrainer:kl_weight   :        1e-07
[2023-01-14 14:04:06,626] INFO - scETM.trainers.trainer_utils: loss        :      8.793
INFO:scETM.trainers.trainer_utils:loss        :      8.793
[2023-01-14 14:04:06,629] INFO - scETM.trainers.trainer_utils: nll         :      8.793
INFO:scETM.trainers.trainer_utils:nll         :      8.793
[2023-01-14 14:04:06,632] INFO - scETM.trainers.trainer_uti

loss:      8.334	nll:      8.333	kl_delta:      328.4	max_norm:     0.7902	Epoch 11999/12000	Next ckpt:   12000

[2023-01-14 15:19:14,273] INFO - scETM.trainers.UnsupervisedTrainer: pmem(rss=12004614144, vms=43964141568, shared=1049694208, text=2043904, lib=0, data=14259892224, dirty=0)
INFO:scETM.trainers.UnsupervisedTrainer:pmem(rss=12004614144, vms=43964141568, shared=1049694208, text=2043904, lib=0, data=14259892224, dirty=0)
[2023-01-14 15:19:14,275] INFO - scETM.trainers.UnsupervisedTrainer: lr          :    2.784e-05
INFO:scETM.trainers.UnsupervisedTrainer:lr          :    2.784e-05
[2023-01-14 15:19:14,277] INFO - scETM.trainers.UnsupervisedTrainer: kl_weight   :        1e-07
INFO:scETM.trainers.UnsupervisedTrainer:kl_weight   :        1e-07
[2023-01-14 15:19:14,281] INFO - scETM.trainers.trainer_utils: loss        :      8.346
INFO:scETM.trainers.trainer_utils:loss        :      8.346
[2023-01-14 15:19:14,283] INFO - scETM.trainers.trainer_utils: nll         :      8.346
INFO:scETM.trainers.trainer_utils:nll         :      8.346
[2023-01-14 15:19:14,284] INFO - scETM.trainers.trainer_uti

Reload the trained checkpoint and extract gene embedding vectors

In [15]:
# Rebuild the same architecture used for training so the checkpoint's tensors
# line up; these arguments must match the training cell.
model = scETM(adata.n_vars,
              n_batches=adata.obs["batch_indices"].nunique(),
              trainable_gene_emb_dim=emb_dim,
             )

# Final-epoch checkpoint from the timestamped training run directory.
# Update this path after re-training (the directory name changes every run).
ckpt_path = "./submission/checkpoints/scETM_01_14-12_57_32/model-12000"

# map_location="cpu" lets a GPU-saved checkpoint load even when no GPU is
# visible; load_state_dict then copies the tensors onto the model's own device.
model.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
model.eval()

[2023-01-18 11:36:42,294] INFO - scETM.logging_utils: scETM.__init__(15077, n_batches = 4, trainable_gene_emb_dim = 32)


scETM(
  (q_delta): Sequential(
    (0): Linear(in_features=15077, out_features=128, bias=True)
    (1): ReLU()
    (2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Dropout(p=0.1, inplace=False)
  )
  (mu_q_delta): Linear(in_features=128, out_features=50, bias=True)
  (logsigma_q_delta): Linear(in_features=128, out_features=50, bias=True)
  (rho_trainable_emb): PartlyTrainableParameter2D(height=32, fixed=0, trainable=15077)
)

In [9]:
# Run the trained model over all cells; the code below relies on this call
# populating adata.varm['rho'] with the learned gene embedding matrix.
model.get_all_embeddings_and_nll(adata)

out_dir = "./submission/embedding"
# np.save does not create missing parent directories.
os.makedirs(out_dir, exist_ok=True)

gene_embedding = np.array(adata.varm['rho'])
# Save names as a fixed-width unicode array instead of an object array, so the
# file loads with np.load's default allow_pickle=False.
gene_names = adata.var_names.to_numpy(dtype=str)
# AnnData aligns varm rows with var_names: row i of the embedding is gene i.
assert gene_embedding.shape[0] == len(gene_names), "embedding rows and gene names are misaligned"

# np.save appends the ".npy" extension automatically.
np.save(f"{out_dir}/gene_embedding_{emb_dim}", gene_embedding)
np.save(f"{out_dir}/gene_names_{emb_dim}", gene_names)