In [None]:
import pandas as pd
from io import StringIO
import scipy
from scipy import sparse
import scvi
from scvi.dataset import GeneExpressionDataset
import numpy
import sparsedat
from sparsedat import wrappers 
from sparsedat import Data_Type
from sparsedat import Sparse_Data_Table as SDT
import scipy
from scvi.models.vae import VAE
from scvi.inference import UnsupervisedTrainer
from scvi.models.scanvi import SCANVI
from plotly import offline as plotly
from sklearn.manifold import TSNE
from plotly import graph_objects
from umap import UMAP
from sklearn.cluster import AgglomerativeClustering
from sklearn.cluster import KMeans
import torch
import scanpy as sc
import pandas as pd
from io import StringIO
from sparsedat import wrappers
from scipy.sparse import coo_matrix, vstack
import itertools
from numpy import load
import random
import os
from scvi import set_seed
import pickle
import anndata
from scrapi.dataset import Gene_Expression_Dataset as GED

In [None]:
# Random seed passed to the scVI trainer, scvi.set_seed and TSNE in later cells.
SEED = 1040

# Get healthy venous blood PBMC scRNA-seq data from Hashimoto (2019) dataset (https://www.pnas.org/content/116/48/24242#sec-10)

In [None]:
# Barcode / sample table, read without a header row; the third column (label 2)
# is discarded and the remaining two columns are named explicitly.
barcodes_path = os.path.join('data', 'Hashimoto2019', 'cell_barcodes.txt')
hashimoto_barcodes = pd.read_csv(barcodes_path, sep='\t', header=None).drop(columns=[2])
hashimoto_barcodes.columns = ['barcode', 'sample_id']

In [None]:
# Convert the raw UMI text table to Matrix Market format once; later runs reuse
# the converted files. Only the presence of hashimoto.mtx is checked.
hashimoto_directory = os.path.join("data", "Hashimoto2019")
umi_text_path = os.path.join(hashimoto_directory, "01.UMI.txt")
umi_sdt_path = os.path.join(hashimoto_directory, "01.UMI.sdt")
hashimoto_mtx_path = os.path.join(hashimoto_directory, "hashimoto.mtx")

if os.path.exists(hashimoto_mtx_path):
    sdt = SDT(umi_sdt_path)
else:
    # Grabbing the raw counts for study in order to feed into scvi
    sdt = wrappers.load_text(
        umi_text_path,
        separator="\t",
        has_header=True,
        has_row_names=True,
        default_value=0,
        data_type=Data_Type.INT
    )

    # Result is not reassigned, so transpose() is relied on to work in place.
    # Presumably this puts cells on rows - TODO confirm against sparsedat.
    sdt.transpose()

    # Persist the table, then reopen it from disk before exporting.
    sdt.save(umi_sdt_path)
    sdt = SDT(umi_sdt_path)

    wrappers.to_mtx(
        sdt,
        os.path.join(hashimoto_directory, "barcodes.txt"),
        os.path.join(hashimoto_directory, "genes.txt"),
        hashimoto_mtx_path,
        column_based=False
    )

# Keep only the count matrix of the loaded AnnData object.
hashimoto_mtx = anndata.read_mtx(hashimoto_mtx_path).X

In [None]:
# Gene list written next to the .mtx file; read without a header, so the first
# column carries the integer label 0.
genes_path = os.path.join("data", "Hashimoto2019", "genes.txt")
genes = pd.read_csv(genes_path, sep='\t', header=None)
gene_ensg_array = genes.loc[:, 0].values

# Map the Ensembl IDs of the Hashimoto data to gene names using Hu's data (ensg -> gene name)
def get_ensg_gene_names_hu(ensg_array):
    """Return the gene name for each Ensembl ID, in input order.

    The mapping comes from the `var` table of data/Hu2019/hu_smith.h5ad, which
    is read from disk on every call and must provide `ensembl_id` and
    `gene_name` columns. An ID missing from that table makes the `.loc`
    lookup raise KeyError.
    """
    # Load the Hu dataset to get the gene mapping for the Hashimoto data
    hu_gene_table = sc.read_h5ad(os.path.join("data", "Hu2019", "hu_smith.h5ad")).var

    # Re-key the annotation table by Ensembl ID so each ID can be looked up directly.
    genes_by_ensembl_id = pd.DataFrame(
        hu_gene_table.values,
        index=hu_gene_table['ensembl_id'],
        columns=hu_gene_table.columns,
    )
    genes_by_ensembl_id = genes_by_ensembl_id.drop(['ensembl_id'], axis=1)

    return [genes_by_ensembl_id.loc[ensembl_id]['gene_name'] for ensembl_id in ensg_array]
               
# Attach the looked-up gene names and give the Ensembl ID column a proper label.
gene_names = get_ensg_gene_names_hu(gene_ensg_array)

genes = genes.assign(gene_name=gene_names).rename(columns={0: "ensembl_id"})

In [None]:
# Samples to keep; every other sample_id in the barcode table is left out.
sample_ids = ["CT1", "CT2", "CT3", "CT4", "CT5"]

study_transcript_counts = []
study_gene_names = []
subject_barcodes = {}
cell_barcodes = []

# Split the count matrix into one block of rows per sample. Rows of
# hashimoto_mtx are selected with a mask built from hashimoto_barcodes, so both
# are assumed to list cells in the same order - TODO confirm.
for subject_index, sample_id in enumerate(sample_ids):

    subject_mask = hashimoto_barcodes["sample_id"] == sample_id
    subject_barcode_series = hashimoto_barcodes["barcode"][subject_mask]

    study_transcript_counts.append(hashimoto_mtx[subject_mask.values, :])
    # Every sample shares the same gene table.
    study_gene_names.append(genes)

    cell_barcodes.extend(subject_barcode_series.values)
    # Subjects are keyed S1..S5 in the order given by sample_ids.
    subject_barcodes["S%i" % (subject_index + 1)] = subject_barcode_series

cell_barcodes = numpy.array(cell_barcodes)

In [None]:
# Build (1) an Ensembl ID -> gene name lookup over all studies and (2) the list
# of Ensembl IDs present in every study.
ensembl_id_gene_name_lookup = {}

ensembl_id_intersection = None

for study_genes in study_gene_names:

    study_ensembl_ids = study_genes["ensembl_id"].values

    if ensembl_id_intersection is None:
        ensembl_id_intersection = set(study_ensembl_ids)
    else:
        ensembl_id_intersection = ensembl_id_intersection.intersection(study_ensembl_ids)

    # Later rows / studies overwrite earlier entries for the same Ensembl ID.
    ensembl_id_gene_name_lookup.update(zip(study_ensembl_ids, study_genes["gene_name"].values))

# Iterating a set of strings follows hash order, which changes between
# interpreter runs. Converting the set straight to a list would therefore give
# a different gene (column) order after every kernel restart, and weights
# cached from an earlier run would no longer match the dataset's columns.
# Use the first study's gene order instead (duplicates dropped).
if study_gene_names:
    shared_ensembl_ids = ensembl_id_intersection
    ensembl_id_intersection = [
        ensembl_id
        for ensembl_id in dict.fromkeys(study_gene_names[0]["ensembl_id"].values)
        if ensembl_id in shared_ensembl_ids
    ]
else:
    ensembl_id_intersection = []

In [None]:
# Restrict every study to the shared genes (in ensembl_id_intersection order)
# and stack the studies into one matrix with a batch index per cell.
filtered_study_transcript_counts = []
combined_batch_indices = []

for study_index, transcript_counts in enumerate(study_transcript_counts):

    # Resolve column positions against this study's own gene table, so the
    # selection stays correct even if studies list genes in different orders.
    gene_name_index = {
        ensembl_id: column_index
        for column_index, ensembl_id in enumerate(
            study_gene_names[study_index]["ensembl_id"].values.tolist()
        )
    }

    gene_indices = [gene_name_index[ensembl_id] for ensembl_id in ensembl_id_intersection]

    filtered_study_transcript_counts.append(transcript_counts[:, gene_indices])

    # One batch label per cell (row) of this study.
    combined_batch_indices.extend([study_index] * transcript_counts.shape[0])

combined_transcript_counts = sparse.vstack(filtered_study_transcript_counts)
combined_gene_names = [ensembl_id_gene_name_lookup[ensembl_id] for ensembl_id in ensembl_id_intersection]

In [None]:
# Make gene names unique. The first occurrence keeps its name; each later
# occurrence gets the next free "-N" suffix, starting at "-2".
existing_gene_name_counts = {}
new_gene_names = []

# Names that may not be generated: every original name plus every suffix handed
# out so far, so a generated name can never clash with a real gene name.
taken_gene_names = set(combined_gene_names)

for gene in combined_gene_names:

    if gene in existing_gene_name_counts:
        occurrence = existing_gene_name_counts[gene] + 1
        candidate = "%s-%i" % (gene, occurrence)
        while candidate in taken_gene_names:
            occurrence += 1
            candidate = "%s-%i" % (gene, occurrence)
        existing_gene_name_counts[gene] = occurrence
        taken_gene_names.add(candidate)
        gene = candidate
    else:
        existing_gene_name_counts[gene] = 1

    new_gene_names.append(gene)

combined_gene_names = new_gene_names

In [None]:
# Wrap the stacked counts in an scVI dataset; each cell's batch index is the
# position of its sample in the list of studies.
ged = GeneExpressionDataset()
ged.populate_from_data(
    combined_transcript_counts,
    batch_indices=combined_batch_indices,
    gene_names=combined_gene_names,
)

In [None]:
# Initialize the training parameters for the variational autoencoder
n_epochs = 50
learning_rate = 1e-3

# Cache files for the latent space and the trained weights. The file names
# encode the hyperparameters, so changing them produces new cache files.
training_parameters = (n_epochs, learning_rate)
hashimoto_directory = os.path.join("data", "Hashimoto2019")
latent_pickle_file_name = os.path.join(
    hashimoto_directory,
    "hashimoto2019_nepoch_%i_lr_%.1e_latent.pickle" % training_parameters,
)
weights_pickle_file_name = os.path.join(
    hashimoto_directory,
    "hashimoto2019_nepoch_%i_lr_%.1e_weights.pickle" % training_parameters,
)

In [None]:
# Seed before the model is built. Constructing the network draws its initial
# weights, and the only other set_seed call happens right before training, so
# without this line each fresh run would start from different weights.
set_seed(SEED)

vae = VAE(ged.nb_genes, n_batch=ged.n_batches)

# 80% of the cells are used for training; metrics are computed every epoch.
trainer = UnsupervisedTrainer(vae, ged, train_size=0.8, frequency=1, seed=SEED)

In [None]:
# Train the model, or restore a previous run from its cache files.
# Restoring needs BOTH files (latent space and weights). If either one is
# missing the model is retrained and both files are rewritten.
cache_is_complete = (
    os.path.exists(latent_pickle_file_name)
    and os.path.exists(weights_pickle_file_name)
)

if not cache_is_complete:

    set_seed(SEED)

    trainer.train(n_epochs=n_epochs, lr=learning_rate)
    torch.save(trainer.model.state_dict(), weights_pickle_file_name)

    # Posterior over every cell of the dataset (indices cover all of ged).
    full = trainer.create_posterior(trainer.model, ged, indices=numpy.arange(len(ged)))
    latent, _, _ = full.sequential().get_latent()

    with open(latent_pickle_file_name, 'wb') as latent_pickle_file:
        pickle.dump(latent, latent_pickle_file, protocol=pickle.HIGHEST_PROTOCOL)

else:

    # map_location="cpu" lets weights saved on a GPU machine be read on a
    # CPU-only one; load_state_dict then copies them into the model's
    # existing parameters.
    weights_pickle_file = torch.load(weights_pickle_file_name, map_location="cpu")
    trainer.model.load_state_dict(weights_pickle_file)

    # NOTE: unpickling can execute arbitrary code. Only load cache files that
    # were produced by this notebook.
    with open(latent_pickle_file_name, 'rb') as latent_pickle_file:
        latent = pickle.load(latent_pickle_file)

In [None]:
# Project the latent space to two dimensions for plotting.
tsne_model = TSNE(n_components=2, random_state=SEED)
tsne = tsne_model.fit_transform(latent)

In [None]:
# Cluster the cells in latent space (not in the t-SNE embedding) into 13 groups.
clustering_model = AgglomerativeClustering(n_clusters=13)
clusters = clustering_model.fit_predict(latent)

In [None]:
# One scatter trace per cluster so each cluster gets its own legend entry.
traces = []

for cluster_index in range(clusters.max()+1):

    in_cluster = clusters == cluster_index

    traces.append(
        graph_objects.Scatter(
            x=tsne[in_cluster, 0],
            y=tsne[in_cluster, 1],
            name="Cluster %i" % cluster_index,
            mode="markers",
        )
    )

figure = graph_objects.Figure(traces)

plotly.iplot(figure)

In [None]:
# Manual cluster id -> cell type assignment.
# NOTE(review): cluster 1 has no entry, so its cells get no cell-type label
# when labels are built from this map - TODO confirm this is intentional.
cluster_cell_marker_map = {
    3: 'CD4 T Cells',
    7: 'CD4 T Cells',
    2: 'CD8 T Cells',
    9: 'CD8 T Cells',
    4: 'NK Cells',
    5: 'NK Cells',
    11: 'B Cells',
    12: 'B Cells',
    0: 'CD14 Monocytes',
    10: 'CD14 Monocytes',
    6: 'CD16 Monocytes',
    8: 'Dendritic Cells',
}

cluster_cell_marker_map

In [None]:
# Collect the barcodes belonging to each label: first the cell types (several
# clusters may share one cell type), then one entry per subject.
label_barcodes = {}

for cluster, label in cluster_cell_marker_map.items():

    cluster_barcodes = cell_barcodes[clusters == cluster]

    # Merge into the set already started for this cell type, if any.
    label_barcodes.setdefault(label, set()).update(cluster_barcodes)

# Subject entries are stored as they are in subject_barcodes (not converted to sets).
label_barcodes.update(subject_barcodes)

GED.write_label_cells_to_file(label_barcodes, os.path.join("data", "Hashimoto2019", "labels.csv"))

In [None]:
# # Note: uncomment the code below to color the t-SNE embedding by the expression of a marker gene

# GENE = "CD19"

# gene_index = combined_gene_names.index(GENE)

# traces = []
    
# x = tsne[:, 0]
# y = tsne[:, 1]

# trace = graph_objects.Scatter(
#     x=x,
#     y=y,
#     name="Cluster %i" % cluster_index,
#     mode="markers",
#     marker={
#         "color": combined_transcript_counts[:, gene_index].toarray().flatten()

#     })

# traces.append(trace)

# figure = graph_objects.Figure(traces)

# plotly.iplot(figure)