In [3]:
import anndata
# import scvi
import scanpy as sc
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import argparse
import scanpy as sc
import os
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.signal import savgol_filter
# import umap
import anndata as ad
import pandas as pd
from collections import Counter
import anndata
from sklearn.model_selection import train_test_split

from lib_vqvae.vqvae import VQVAE
from lib_vqvae.dataset import *
import lib_vqvae.feature_spectrum as feature_spectrum

import lib_metrics.benchmark as bmk
import lib_metrics.metrics as metrics
from lib_metrics.utils_integration import *

### read your dataset

In [4]:
adata = ad.read_h5ad("./data/liver_test.h5ad")
adata

AnnData object with n_obs × n_vars = 27436 × 43878
    obs: 'cid', 'seq_tech', 'donor_ID', 'donor_gender', 'donor_age', 'donor_status', 'original_name', 'organ', 'region', 'subregion', 'sample_status', 'treatment', 'ethnicity', 'cell_type', 'cell_id', 'study_id'
    var: 'vst.mean', 'vst.variance', 'vst.variance.expected', 'vst.variance.standardized', 'vst.variable'
    obsm: 'umap'

In [5]:
var_names_df = pd.read_csv("./data/hvg5000_gene_names.csv")
gene_names = var_names_df['gene'].tolist()
test_condition = adata.var_names.isin(gene_names)
adata = adata[:,test_condition]
adata

View of AnnData object with n_obs × n_vars = 27436 × 5000
    obs: 'cid', 'seq_tech', 'donor_ID', 'donor_gender', 'donor_age', 'donor_status', 'original_name', 'organ', 'region', 'subregion', 'sample_status', 'treatment', 'ethnicity', 'cell_type', 'cell_id', 'study_id'
    var: 'vst.mean', 'vst.variance', 'vst.variance.expected', 'vst.variance.standardized', 'vst.variable'
    obsm: 'umap'

### load the trained model

In [6]:
# load the saved model
parser = argparse.ArgumentParser()
parser.add_argument("--train_batch_size", type=int, default=64)
parser.add_argument("--eval_batch_size", type=int, default=2000)
parser.add_argument("--encoder_hidden_dim", type=list, default=[1600,1024,800])
parser.add_argument("--decoder_hidden_dim", type=list, default=[800,1024,1600])
parser.add_argument("--codebook_dim", type=int, default=8)
parser.add_argument("--n", type=int, default=32)
parser.add_argument("--n_codebooks", type=int, default=256)
parser.add_argument("--gamma", type=float, default=5)
parser.add_argument("--beta", type=float, default=10)
parser.add_argument('--weight_decay', type=float, default=5e-4)
parser.add_argument("--lr", type=float, default=1e-4) # 1e-3
parser.add_argument('--max_iter', type=int, default=20)
parser.add_argument('--seed', type=int, default=2024)
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument("--dropout_rate", type=float, default=0.2)
parser.add_argument('--mode', type=str, default="vqvae")
parser.add_argument('--hvg', type=int, default=5000)

# whether or not to save model
parser.add_argument("-save", action="store_true", default=True)
parser.add_argument("--log_val", type=int, default=2)

# -- para for paths
parser.add_argument('--save_path', type=str, default='./results/multi_organs_fullgenes/')
parser.add_argument('--test_on', type=str, default='test')

args = parser.parse_args([])

# -- device
if torch.cuda.is_available():
    args.device='cuda'
    torch.cuda.set_device(args.gpu)
else:
    args.device='cpu' 
    
data_dim = 5000
vae = VQVAE(data_dim, args.n, args.codebook_dim, args.n_codebooks, args.encoder_hidden_dim, args.decoder_hidden_dim, args.beta, args.gamma, args.dropout_rate).to(args.device)

vae.load_state_dict(torch.load('./model/model.pth')['model'])
vae.eval() 

VQVAE(
  (encoder): Encoder(
    (network): Sequential(
      (0): Linear(in_features=5000, out_features=1600, bias=True)
      (1): LayerNorm((1600,), eps=1e-05, elementwise_affine=True)
      (2): PReLU(num_parameters=1)
      (3): Dropout(p=0.2, inplace=False)
      (4): Linear(in_features=1600, out_features=1024, bias=True)
      (5): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (6): PReLU(num_parameters=1)
      (7): Dropout(p=0.2, inplace=False)
      (8): Linear(in_features=1024, out_features=800, bias=True)
      (9): LayerNorm((800,), eps=1e-05, elementwise_affine=True)
      (10): PReLU(num_parameters=1)
      (11): Dropout(p=0.2, inplace=False)
      (12): Linear(in_features=800, out_features=256, bias=True)
    )
  )
  (vector_quantization): VectorQuantizer(
    (embedding): Embedding(256, 8)
  )
  (decoder): Decoder(
    (network): Sequential(
      (0): Linear(in_features=256, out_features=800, bias=True)
      (1): LayerNorm((800,), eps=1e-05, elementwise

### get the referenced results

In [8]:
new_adata = vae.get_adata_codebook_index(adata, mode=args.mode, device=args.device)



In [9]:
new_adata.obsm['code_index']

array([[187, 187, 188, ...,  98,  25, 223],
       [205,  69, 140, ...,  36, 143,  24],
       [ 33, 187,  36, ..., 117, 143,  24],
       ...,
       [105, 188, 188, ..., 117, 143,  24],
       [135, 123,  81, ...,  36, 143, 223],
       [190,  98, 105, ..., 191,  25,  24]])

In [10]:
new_adata.obsm['latent']

array([[ 4.3742601e-09, -4.0678037e-11, -1.5584246e-13, ...,
        -2.2219011e-10, -5.6158284e-17,  5.5456480e-11],
       [ 7.9681539e-10, -7.4902818e-12, -2.7987551e-14, ...,
        -2.6348945e-10, -7.8751446e-17,  6.5588653e-11],
       [-2.2198401e-09,  2.0596032e-11,  8.0916922e-14, ...,
        -2.6348945e-10, -7.8751446e-17,  6.5588653e-11],
       ...,
       [ 1.5299358e-09, -1.4224762e-11, -5.5298183e-14, ...,
        -2.6348945e-10, -7.8751446e-17,  6.5588653e-11],
       [-8.4892188e-10,  7.7632961e-12,  3.1937282e-14, ...,
        -2.2219011e-10, -5.6158284e-17,  5.5456480e-11],
       [ 1.8610150e-10, -1.8036458e-12, -5.7876406e-15, ...,
        -2.6348945e-10, -7.8751446e-17,  6.5588653e-11]], dtype=float32)