In [None]:
import scanpy as sc
import pandas as pd
from io import StringIO
import anndata
import scipy
from scipy import sparse
import scvi
from scvi.dataset import GeneExpressionDataset
import numpy
from sparsedat import wrappers 
import scipy
from scvi.models.vae import VAE
from scvi.inference import UnsupervisedTrainer
from scvi.models.scanvi import SCANVI
import scanpy as sc
from plotly import offline as plotly
from sklearn.manifold import TSNE
from plotly import graph_objects
from umap import UMAP
from sklearn.cluster import AgglomerativeClustering
from sklearn.cluster import KMeans
import torch
import scanpy as sc
import pandas
from io import StringIO
import anndata
from sparsedat import wrappers
from scipy.sparse import coo_matrix, vstack
import itertools
from numpy import load
import pickle
import os
import random
from scvi import set_seed
import pickle
from scrapi.dataset import Gene_Expression_Dataset as GED

In [None]:
# Global random seed, passed to the scVI trainer, scvi.set_seed and t-SNE below.
SEED = 1040

# Get healthy venous-blood PBMC scRNA-seq data from the Lee et al. (2020) dataset (GEO accession GSE149689: https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE149689)

In [None]:
# Extract the healthy-control (venous blood PBMC) cells from the study.
# Each control subject becomes one batch. Cells are appended subject by subject
# in the order of `control_subject_descriptors`, so the per-cell lists built
# here (count matrices, barcodes) all share the same row order.

lee_dataset = sc.read_h5ad(os.path.join("data", "Lee2020", "lee_GSE149689.h5ad"))

lee_obs = lee_dataset.obs

lee_gene_df = lee_dataset.var

lee_mtx = lee_dataset.X

control_subject_descriptors = [
    #S1, age: 63, female
    'Sample 5_Normal 1 scRNA-seq [SW107]',
    #S2, age: 54, female
    'Sample 13_Normal 2 scRNA-seq [SW115]',
    #S3, age: 67, female
    'Sample 14_Normal 3 scRNA-seq [SW116]',
    #S4, age: 64, male
    'Sample 19_Normal 4 scRNA-seq [SW121]'
]

study_transcript_counts = []
study_gene_names = []
subject_barcodes = {}
cell_barcodes = []

for subject_index, subject_descriptor in enumerate(control_subject_descriptors):
    subject_mask = (lee_obs['sample_description'] == subject_descriptor)

    # A mistyped descriptor would otherwise yield an empty batch that is later
    # handed to the VAE without any warning.
    if not subject_mask.any():
        raise ValueError("No cells found for control subject %r" % subject_descriptor)

    subject_data = lee_mtx[subject_mask.values, :]

    cell_barcodes.extend(lee_obs["barcode"][subject_mask].values)

    study_transcript_counts.append(subject_data)
    # Every subject comes from the same h5ad file, so they share one gene table.
    study_gene_names.append(lee_gene_df)
    subject_barcodes["S%i" % (subject_index + 1)] = lee_obs["barcode"][subject_mask]

cell_barcodes = numpy.array(cell_barcodes)

In [None]:
# Combine the per-subject count matrices into a single cells x genes matrix.
# The matrix is restricted to the Ensembl IDs present in every subject.
# Also build the matching list of gene symbols, and one batch index per cell
# (the subject's position in the study lists).
ensembl_id_gene_name_lookup = {}

ensembl_id_intersection = None

for study_gene_df in study_gene_names:
    study_ensembl_ids = study_gene_df["ensembl_id"].values

    if ensembl_id_intersection is None:
        ensembl_id_intersection = set(study_ensembl_ids)
    else:
        ensembl_id_intersection = ensembl_id_intersection.intersection(study_ensembl_ids)

    # If an Ensembl ID maps to different symbols across studies, the last study wins.
    ensembl_id_gene_name_lookup.update(zip(study_ensembl_ids, study_gene_df["gene_name"].values))

# Iterating a set of strings gives an order that changes between Python
# processes (hash randomization). That would change the gene/column order fed
# to the VAE from one kernel session to the next, so cached weights and latents
# would no longer match the data. Keep the first study's gene order instead,
# de-duplicated.
ensembl_id_intersection = list(dict.fromkeys(
    ensembl_id
    for ensembl_id in study_gene_names[0]["ensembl_id"].values
    if ensembl_id in ensembl_id_intersection
))

filtered_study_transcript_counts = []
combined_batch_indices = []

for study_index, transcript_counts in enumerate(study_transcript_counts):

    # Column position of each Ensembl ID in *this* study's gene table.
    gene_name_index = {
        ensembl_id: index
        for index, ensembl_id in enumerate(study_gene_names[study_index]["ensembl_id"].values.tolist())
    }

    gene_indices = [gene_name_index[gene] for gene in ensembl_id_intersection]

    filtered_study_transcript_counts.append(transcript_counts[:, gene_indices])

    combined_batch_indices.extend([study_index] * transcript_counts.shape[0])

# Request CSR explicitly: later cells column-slice this matrix.
combined_transcript_counts = sparse.vstack(filtered_study_transcript_counts, format="csr")
combined_gene_names = [ensembl_id_gene_name_lookup[ensembl_id] for ensembl_id in ensembl_id_intersection]

In [None]:
# Make gene symbols unique. Several Ensembl IDs can share one symbol, so the
# first occurrence of "X" keeps its name, the second becomes "X-2", the third
# "X-3", and so on.
existing_gene_name_counts = {}
new_gene_names = []

for gene in combined_gene_names:

    if gene in existing_gene_name_counts:
        existing_gene_name_counts[gene] += 1
        # The counter already includes this occurrence, so it is used as-is as
        # the suffix (adding 1 would make the first duplicate "X-3").
        gene = "%s-%i" % (gene, existing_gene_name_counts[gene])
    else:
        existing_gene_name_counts[gene] = 1

    new_gene_names.append(gene)

combined_gene_names = new_gene_names

In [None]:
# Wrap the combined counts in an scVI dataset. Each cell's batch index is the
# index of its subject; the VAE below receives the number of batches via n_batch.
ged = GeneExpressionDataset()
ged.populate_from_data(
    combined_transcript_counts,
    batch_indices=combined_batch_indices,
    gene_names=combined_gene_names,
)

In [None]:
# Variational autoencoder training hyperparameters, plus the number of clusters
# used later for agglomerative clustering of the latent space.
n_epochs = 50
learning_rate = 1e-3
num_clusters = 15

# Cache paths for the latent embedding and the trained weights. The file names
# encode the hyperparameters, so runs with different settings do not collide.
latent_pickle_file_name = os.path.join(
    "data",
    "Lee2020",
    "lee2020_nepoch_%i_lr_%.1e_latent.pickle" % (n_epochs, learning_rate),
)
weights_pickle_file_name = os.path.join(
    "data",
    "Lee2020",
    "lee2020_nepoch_%i_lr_%.1e_weights.pickle" % (n_epochs, learning_rate),
)

In [None]:
# One VAE input per retained gene and one batch category per subject.
vae = VAE(ged.nb_genes, n_batch=ged.n_batches)
# 80% of the cells form the training split. frequency=1 presumably means the
# trainer evaluates metrics every epoch (TODO: confirm against the scVI version in use).
trainer = UnsupervisedTrainer(
    vae,
    ged,
    train_size=0.8,
    frequency=1,
    seed=SEED,
)

In [None]:
# Train the VAE and cache its weights and latent embedding. When BOTH cache
# files already exist, restore them instead of retraining. The load branch
# reads both files, so both must exist before it is taken.
if not (os.path.exists(latent_pickle_file_name) and os.path.exists(weights_pickle_file_name)):

    set_seed(SEED)

    trainer.train(n_epochs=n_epochs, lr=learning_rate)
    torch.save(trainer.model.state_dict(), weights_pickle_file_name)

    # Posterior over every cell. sequential() presumably keeps dataset order, so
    # latent rows line up with cell_barcodes (TODO: confirm for this scVI version).
    full = trainer.create_posterior(trainer.model, ged, indices=numpy.arange(len(ged)))
    latent, _, _ = full.sequential().get_latent()

    with open(latent_pickle_file_name, 'wb') as latent_pickle_file:
        pickle.dump(latent, latent_pickle_file, protocol=pickle.HIGHEST_PROTOCOL)

else:

    set_seed(SEED)

    # map_location="cpu" lets weights saved from a GPU run load on a CPU-only
    # machine; load_state_dict then copies them onto the model's own parameters.
    state_dict = torch.load(weights_pickle_file_name, map_location="cpu")
    trainer.model.load_state_dict(state_dict)

    full = trainer.create_posterior(trainer.model, ged, indices=numpy.arange(len(ged)))

    # NOTE: pickle.load can execute arbitrary code; only load caches this notebook wrote.
    with open(latent_pickle_file_name, 'rb') as latent_pickle_file:
        latent = pickle.load(latent_pickle_file)

In [None]:
# Seeded 2-D t-SNE embedding of the VAE latent space, used only for plotting.
tsne_model = TSNE(n_components=2, random_state=SEED)
tsne = tsne_model.fit_transform(latent)

In [None]:
# Partition the cells into num_clusters groups by agglomerative clustering of
# the latent space (scikit-learn's default Ward linkage).
clustering_model = AgglomerativeClustering(n_clusters=num_clusters)
clusters = clustering_model.fit_predict(latent)

In [None]:
# Draw the t-SNE embedding with one trace per cluster, so each cluster gets its
# own legend entry and color.
traces = []

for cluster_index in range(clusters.max() + 1):
    in_cluster = clusters == cluster_index
    traces.append(
        graph_objects.Scatter(
            x=tsne[in_cluster, 0],
            y=tsne[in_cluster, 1],
            name="Cluster %i" % cluster_index,
            mode="markers",
        )
    )

figure = graph_objects.Figure(traces)

plotly.iplot(figure)

In [None]:
# Color every cell on the t-SNE embedding by its total count (the row sum of the
# combined count matrix). Hovering shows the exact value.
total_counts = numpy.asarray(combined_transcript_counts.sum(axis=1)).ravel()

traces = []

x = tsne[:, 0]
y = tsne[:, 1]

trace = graph_objects.Scatter(
    x=x,
    y=y,
    # A fixed, descriptive name. The trace shows all cells, so it must not depend
    # on a loop variable leaked from another cell.
    name="Total counts",
    mode="markers",
    marker={
        "color": total_counts,
        "showscale": True,
    },
    text=total_counts
)

traces.append(trace)

figure = graph_objects.Figure(traces)

plotly.iplot(figure)

In [None]:
# Color every cell on the t-SNE embedding by the fraction of its counts that
# come from genes whose symbol starts with "MT-".
traces = []

x = tsne[:, 0]
y = tsne[:, 1]

mt_gene_mask = numpy.char.startswith(combined_gene_names, "MT-")
mt_counts = numpy.asarray(combined_transcript_counts[:, mt_gene_mask].sum(axis=1)).ravel()
total_counts = numpy.asarray(combined_transcript_counts.sum(axis=1)).ravel()
# Cells with zero total counts get a ratio of 0 instead of NaN plus a RuntimeWarning.
mt_ratio = numpy.divide(
    mt_counts,
    total_counts,
    out=numpy.zeros(total_counts.shape, dtype=float),
    where=total_counts > 0,
)

trace = graph_objects.Scatter(
    x=x,
    y=y,
    # A fixed, descriptive name, independent of loop variables from other cells.
    name="Mitochondrial count fraction",
    mode="markers",
    marker={
        "color": mt_ratio,
        "showscale": True,
    },
    text=mt_ratio
)

traces.append(trace)

figure = graph_objects.Figure(traces)

plotly.iplot(figure)

In [None]:
# Color every cell on the t-SNE embedding by its raw count of one marker gene.
# Raises ValueError if GENE is not among the retained gene symbols.
GENE = "CD14"

gene_index = combined_gene_names.index(GENE)
gene_counts = combined_transcript_counts[:, gene_index].toarray().flatten()

traces = []

x = tsne[:, 0]
y = tsne[:, 1]

trace = graph_objects.Scatter(
    x=x,
    y=y,
    # Name the trace after the plotted gene rather than a leaked loop variable.
    name=GENE,
    mode="markers",
    marker={
        "color": gene_counts,
        "showscale": True,
    },
    text=gene_counts
)

traces.append(trace)

figure = graph_objects.Figure(traces)

plotly.iplot(figure)

In [None]:
# Manual cell-type annotation of the agglomerative clusters. The IDs refer to
# `clusters` as computed above, so they are only valid for that exact clustering
# (same latent space and num_clusters). Several clusters may share a label.
cluster_cell_marker_map = {
    3: 'CD14 Monocytes',
    7: 'CD14 Monocytes',
    13: 'CD16 Monocytes',
    2: 'B Cells',
    4: 'NK Cells',
    10: 'CD8 T Cells',
    6: 'CD4 T Cells',
}

# Candidate labels that were considered but NOT applied, so these clusters are
# not written to the label file:
#   0  -> "Dendritic Cells"
#   11 -> "Red Blood Cells"
#   1, 5, 9, 14 -> "Debris"
# Alternative calls that conflict with the applied labels above:
#   2 -> "Dendritic Cells", 3 -> "Debris"


In [None]:
# Build a label -> barcodes mapping:
#   * one entry per cell-type label: the union of the barcodes of every cluster
#     mapped to that label;
#   * one entry per subject ("S1", "S2", ...), holding that subject's barcodes.
# NOTE(review): barcodes are pooled across subjects. If the same barcode string
# occurs in more than one subject, the cell-type sets cannot tell those cells
# apart. Confirm that lee_obs["barcode"] values are unique.
label_barcodes = {}

for cluster, label in cluster_cell_marker_map.items():
    label_barcodes.setdefault(label, set()).update(cell_barcodes[clusters == cluster])

label_barcodes.update(subject_barcodes)

GED.write_label_cells_to_file(label_barcodes, os.path.join("data", "Lee2020", "labels.csv"))