In [1]:
import anndata as ad
import scanpy as sc

import numpy as np
import pandas as pd
import sklearn as sk
import matplotlib.pyplot as plt
import torch

from persist import PERSIST, ExpressionDataset

In [3]:
adata = ad.read_h5ad('/scratch/nmq407/dvc_neurons.h5Seurat.h5ad')
adata

AnnData object with n_obs × n_vars = 121868 × 10000
    obs: 'nCount_RNA', 'nFeature_RNA', 'percent.mt', 'pool', 'hash.ID', 'treatment', 'run', 'species', 'area', 'orig.clusters', 'zhang.predictions', 'zhang.score', 'ludwig.predictions', 'ludwig.score', 'integrated_snn_res.0.1', 'integrated_snn_res.1', 'sub.cluster', 'neurotransmitter', 'cell.type', 'major.cell.type', 'nCount_rat_RNA', 'nFeature_rat_RNA', 'nCount_SCT', 'nFeature_SCT'
    var: 'features'
    obsm: 'X_pca', 'X_umap'
    varm: 'PCs'

In [4]:
sc.pp.highly_variable_genes(adata, flavor='seurat', n_top_genes=10000, inplace=True)

In [5]:
adata

AnnData object with n_obs × n_vars = 121868 × 10000
    obs: 'nCount_RNA', 'nFeature_RNA', 'percent.mt', 'pool', 'hash.ID', 'treatment', 'run', 'species', 'area', 'orig.clusters', 'zhang.predictions', 'zhang.score', 'ludwig.predictions', 'ludwig.score', 'integrated_snn_res.0.1', 'integrated_snn_res.1', 'sub.cluster', 'neurotransmitter', 'cell.type', 'major.cell.type', 'nCount_rat_RNA', 'nFeature_rat_RNA', 'nCount_SCT', 'nFeature_SCT'
    var: 'features', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'hvg'
    obsm: 'X_pca', 'X_umap'
    varm: 'PCs'

In [6]:
adata.var['highly_variable']

Ccdc83      True
Cnksr3      True
Neurl1      True
Gsto1       True
Mxi1        True
            ... 
Mmp11       True
Ppm1n       True
Bex4        True
Sfxn2.1     True
Anp32a.1    True
Name: highly_variable, Length: 10000, dtype: bool

In [7]:
adata = adata[:,adata.var['highly_variable']]

In [8]:
adata.obs['ludwig.predictions']

SI-TT-A8_CCTCACACATGGCCCA_2    GABA3
SI-TT-A8_CTTTCGGCAGGACTTT_2    GABA3
SI-TT-A8_CCACACTGTGGCCTCA_2    Chat2
SI-TT-A8_GAACTGTAGGCCCACT_2    GABA3
SI-TT-A8_TTCACCGAGGCCTTGC_2    Chat2
                               ...  
SI-TT-H5_TACCCGTCAATCGTCA_2    GABA1
SI-TT-H5_AGATGAATCTAATTCC_2    GABA2
SI-TT-H5_GTAGATCCATCGATAC_2    GABA1
SI-TT-H5_ATTACTCAGAGCATTA_2    GABA1
SI-TT-H5_AGCGTCGCAAGCGATG_2    GABA1
Name: ludwig.predictions, Length: 121868, dtype: object

In [9]:
adata.obs['ludwig.predictions_codes'] = pd.Categorical(adata.obs['ludwig.predictions']).codes


  adata.obs['ludwig.predictions_codes'] = pd.Categorical(adata.obs['ludwig.predictions']).codes
  next(self.gen)


In [10]:
adata

AnnData object with n_obs × n_vars = 121868 × 10000
    obs: 'nCount_RNA', 'nFeature_RNA', 'percent.mt', 'pool', 'hash.ID', 'treatment', 'run', 'species', 'area', 'orig.clusters', 'zhang.predictions', 'zhang.score', 'ludwig.predictions', 'ludwig.score', 'integrated_snn_res.0.1', 'integrated_snn_res.1', 'sub.cluster', 'neurotransmitter', 'cell.type', 'major.cell.type', 'nCount_rat_RNA', 'nFeature_rat_RNA', 'nCount_SCT', 'nFeature_SCT', 'ludwig.predictions_codes'
    var: 'features', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'hvg'
    obsm: 'X_pca', 'X_umap'
    varm: 'PCs'

In [11]:
adata.layers['bin'] = (adata.X>0).astype(np.float32)

In [12]:
print(adata)

AnnData object with n_obs × n_vars = 121868 × 10000
    obs: 'nCount_RNA', 'nFeature_RNA', 'percent.mt', 'pool', 'hash.ID', 'treatment', 'run', 'species', 'area', 'orig.clusters', 'zhang.predictions', 'zhang.score', 'ludwig.predictions', 'ludwig.score', 'integrated_snn_res.0.1', 'integrated_snn_res.1', 'sub.cluster', 'neurotransmitter', 'cell.type', 'major.cell.type', 'nCount_rat_RNA', 'nFeature_rat_RNA', 'nCount_SCT', 'nFeature_SCT', 'ludwig.predictions_codes'
    var: 'features', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'hvg'
    obsm: 'X_pca', 'X_umap'
    varm: 'PCs'
    layers: 'bin'


In [13]:
# Choose training and validation splits. 
# You may want to use a different strategy to choose these - see https://scikit-learn.org/stable/modules/classes.html#module-sklearn.model_selection
train_ind, val_ind = sk.model_selection.train_test_split(np.arange(adata.shape[0]), train_size=0.8)

print(f'{adata.shape[0]} total samples')
print(f'{np.size(train_ind)} in training set')
print(f'{np.size(val_ind)} in validation set')

# These are views, so they do not take up memory
adata_train = adata[train_ind,:]
adata_val = adata[val_ind,:]

121868 total samples
97494 in training set
24374 in validation set


In [None]:
import time
# Get the start time
start_time = time.time()
print(start_time)

# Initialize the dataset for PERSIST
# Note: Here, data_train.layers['bin'] is a sparse array
# data_train.layers['bin'].A converts it to a dense array
train_dataset = ExpressionDataset(adata_train.layers['bin'], adata_train.obs['ludwig.predictions_codes'])
val_dataset = ExpressionDataset(adata_val.layers['bin'], adata_val.obs['ludwig.predictions_codes'])


# Use GPU device if available -- we highly recommend using a GPU!
device = torch.device(torch.cuda.current_device() if torch.cuda.is_available() else 'cpu')

# Number of genes to select within the current selection process.
num_genes = (32, 64, 100)
persist_results = {}

# Set up the PERSIST selector
selector = PERSIST(train_dataset,
                   val_dataset,
                   loss_fn=torch.nn.CrossEntropyLoss(),
                   device=device)
print(device)

# Coarse removal of genes
print('Starting initial elimination...')
candidates, model = selector.eliminate(target=500, max_nepochs=250)
print('Completed initial elimination.')

print('Selecting specific number of genes...')
for num in num_genes:
    inds, model = selector.select(num_genes=num, max_nepochs=250)
    persist_results[num] = inds
print('Done')

# Get the end time
end_time = time.time()
print(time.localtime(end_time))
# Calculate the execution time
execution_time = end_time - start_time

# Format the execution time in a human-readable format
minutes, seconds = divmod(execution_time, 60)
hours, minutes = divmod(minutes, 60)
formatted_time = f"{int(hours)} hours, {int(minutes)} minutes, {int(seconds)} seconds"
print("Execution time:", formatted_time)

1682376696.1420925
cuda:0
Starting initial elimination...
using CrossEntropyLoss, starting with lam = 0.0001


Training epochs:   0%|          | 0/250 [00:00<?, ?it/s]

In [None]:
device

In [None]:
persist_results

In [None]:
minutes, seconds = divmod(end_time, 60)
hours, minutes = divmod(minutes, 60)
formatted_time = f"{int(hours)} hours, {int(minutes)} minutes, {int(seconds)} seconds"

In [None]:
formatted_time

In [None]:
time.localtime(end_time)

In [None]:
# obtain a copy of features from the anndata object
# Note: Without the .copy(), you will modify adata itself, which may be desirable in some use cases.
df = adata.var.copy()

# set a boolean = True for genes selected in any of the rounds
for num in num_genes:
    df[f'persist_set_{num}'] = False
    ind = df.iloc[persist_results[num]].index
    df.loc[ind,f'persist_set_{num}'] = True

In [None]:
# only keep features (genes) that were selected in any set by PERSIST, and save for subsequent use
df = df[df[[f'persist_set_{num}' for num in num_genes]].any(axis=1)]

df.head(2)

In [None]:
sc.pl.dotplot(adata,
              var_names=df[df['persist_set_100']].index.values,
              groupby='polar_label',
              layer='bin', )
plt.show()

In [None]:
df[df['persist_set_100']].index.values

In [None]:
gbr_100 = ['Nrg3', 'Plp1', 'Aqp4', 'Htr3b', 'Rax', 'Rbfox1', 'Il1rapl2', 'Robo1', 'Kcnip4', 'Sgcz', 'Fgf13', 'Cntn4', 'Egfem1', 'Gpc6', 'Hdac9', 'Col25a1', 'Dcc', 'Nkain2', 'Pcdh11x', 'Prkg1', 'Pdgfra', 'Prr16', 'Kctd16', 'Nrxn3', 'Pde10a', 'Plcl1', 'Nrg1', 'Rtl4', 'Bmp4', 'Grm7', 'Ptprk', 'Sgcd', 'Ncam2', 'Zfhx3', 'Erbb4', 'Kirrel3', 'Nxph1', 'Mgat4c', 'Oxr1', 'Sorcs1', 'Pcdh7', 'Adarb2', 'Csmd1', 'Inpp4b', 'Ptprt', 'Trpm3', 'Tenm3', 'Tox', 'Alcam', 'Car10', 'Ntm', 'Slit2', 'Plxdc2', 'Tenm2', 'Luzp2', 'Ptprd', 'Lrp1b', 'Sox5', 'Brinp3', 'Pcdh9', 'Cacna2d3', 'Rmst', 'Rgs6', 'Fat3', 'Cadm2', 'Pde4b', 'Gpc5', 'Cfap299', 'Gria1', 'Arhgap6', 'Pdzrn3', 'P3h2', 'Kcnq3', 'Plcb1', 'Deptor', 'Agrp', 'Pomc', 'Lef1', 'Lmx1a', 'Cntn5', 'Lingo2', 'Zfp804b', 'Galntl6', 'Grm8', 'Hs3st4', 'Fhit', 'Immp2l', 'Lrmda', 'Macrod2', 'Gtdc1', 'Naaladl2', 'Nalf1', 'Slc1a2', 'Ctnna2', 'Slc7a11', 'Prkca', 'Dlg2', 'Gabrg3']

In [None]:
len(gbr_100)

In [None]:
sc.pl.dotplot(adata,
              var_names=gbr_100,
              groupby='polar_label',
              layer='bin')
plt.show()

In [None]:
len(selector.candidates)

In [None]:
selector.loss_fn

In [None]:
selector.preselected

In [None]:
selector.preselected_relative

In [None]:
selector.activation

In [None]:
selector.