In [1]:
import csv
from scipy import sparse
from tqdm import tqdm
import os
import pandas as pd
import numpy as np
base_path = '/bigstore/binfo/mouse/Brain/Sequencing/Allen_10X_SmartSeqTypes/'
genes = list(pd.read_csv(os.path.join(base_path,'dredfish_filtered_genes.csv'),index_col=0,header=None).index)
cells = list(pd.read_csv(os.path.join(base_path,'sorted_cells.csv'),index_col=0,header=None).index)
metadata = pd.read_csv(os.path.join(base_path,'sorted_metadata.csv'),index_col=0).loc[cells]
labels = list(metadata['cell_type_alias_label'])
counts = np.load(os.path.join(base_path,'dredfish_normcounts.npy'))

  interactivity=interactivity, compiler=compiler, result=result)


In [2]:
# # Generate Means and Covariance Matrix for each label
# gmean = np.empty((len(unqtypes),len(genes)))
# gcov = np.empty((len(unqtypes),len(genes),len(genes)),dtype='float16')
# for i,ct in tqdm(enumerate(unqtypes),total=len(unqtypes)):
#     ct_indexes = np.where(np.array(labels)==ct)[0]
#     ct_counts = counts[ct_indexes,:]
#     ct_means = np.mean(ct_counts,axis=0)
#     ct_cov = np.cov(ct_counts.T)
#     gmean[i,:] = ct_means
#     gcov[i,:,:] = ct_cov

In [3]:
# Class Balance
unqtypes = np.unique(labels)
n = 100
class_balanced = np.empty((len(unqtypes)*n,len(genes)))
balanced_labels = np.empty((len(unqtypes)*n)).astype(str)
for i,ct in tqdm(enumerate(unqtypes),total=len(unqtypes)):
    start = i*n
    end = (i+1)*n
    ct_indexes = np.where(np.array(labels)==ct)[0]
    if len(ct_indexes)>n:
        ct_indexes = np.random.choice(ct_indexes,n,replace=False)
    else:
        ct_indexes = np.random.choice(ct_indexes,n,replace=True)
    balanced_labels[start:end] = ct
    class_balanced[start:end,:] = counts[ct_indexes,:]

100%|██████████| 374/374 [01:25<00:00,  4.36it/s]


In [4]:
# Dimensional Reduce
from sklearn.decomposition import PCA
nbits = 12
pca = PCA(n_components=nbits)
pca_out = pca.fit(class_balanced)
nonpositive_loadings = pca_out.components_.T
loadings = np.empty((len(genes),2*nbits))
loadings[:,0:nbits] = nonpositive_loadings
loadings[:,nbits:2*nbits] = -nonpositive_loadings
loadings[loadings<0] = 0
compressed = np.matmul(counts,loadings).astype('float16')

In [56]:
# Generate Means and Covariance Matrix for each label
from scipy.stats import multivariate_normal
# nbits = 24
# csmean = np.empty((len(unqtypes),nbits))
# cscov = np.empty((len(unqtypes),nbits,nbits),dtype='float16')
likelihood = np.empty((len(cells),len(unqtypes)))
for i,ct in tqdm(enumerate(unqtypes),total=len(unqtypes)):
    ct_indexes = np.where(np.array(labels)==ct)[0]
    ct_compressed = compressed[ct_indexes,:]
    ct_means = np.mean(ct_compressed,axis=0)
    ct_cov = np.cov(ct_compressed.T)
    likelihood[:,i] = multivariate_normal.pdf(compressed,mean=ct_means, cov=ct_cov,allow_singular=True)
#     csmean[i,:] = ct_means
#     cscov[i,:,:] = ct_cov
max_likelihood = np.max(likelihood,axis=1)

100%|██████████| 374/374 [05:43<00:00,  1.09it/s]


In [69]:
labels = np.array(labels)
prior = np.array([len(labels[(labels==ct)])/len(labels) for ct in tqdm(unqtypes)])


100%|██████████| 374/374 [00:08<00:00, 44.15it/s]


In [91]:
log_prior = np.log10(prior)
log_likelihood = np.log10(likelihood)


  


In [92]:
posterior = log_likelihood+log_prior

In [95]:
predicted_idx = np.argmax(posterior,axis=1)

In [96]:
predicted_ct = unqtypes[predicted_idx]

In [97]:
predicted_ct[0]

'183_L2/3 IT CTX'

In [104]:
ct = unqtypes[0]
a = 0
for ct in unqtypes:
    true = len(labels[labels==ct])
    accurate = len(predicted_ct[(labels==ct)&(predicted_ct==ct)])
    a+=accurate
    print(ct,round(100*accurate/true,4))

100_Sst 0.9524
101_Sst 1.4451
102_Sst 3.9537
103_Sst 3.5294
104_Sst 0.0
105_Pvalb 0.0
106_Pvalb 0.0
107_Pvalb 0.0
108_Pvalb 0.0
109_Pvalb 0.0
10_Lamp5 43.1293
110_Pvalb 3.0612
111_Pvalb 0.091
112_Pvalb 1.6239
113_Pvalb 0.6359
114_Pvalb 8.008
115_Pvalb 0.2276
116_Pvalb 4.2283
117_Pvalb 0.5018
118_Pvalb 0.0
119_Pvalb Vipr2 1.1628
11_Lamp5 40.4119
120_DG 92.9412
121_DG 92.5
122_DG 79.6748
123_DG 96.1538
124_DG 61.0082
125_DG 97.5513
126_L2/3 IT APr 68.2266
127_L2/3 IT APr 75.4286
128_L2/3 IT APr 67.5227
129_L2/3 IT APr 76.9883
12_Lamp5 46.4309
130_L2 IT RSPv 20.7265
131_L2 IT RSPv 73.5391
132_L2/3 IT PPP 73.042
133_L2/3 IT PPP 56.1676
134_L2/3 IT PPP 80.8611
135_L2/3 IT PPP 52.0216
136_L2/3 IT PPP 84.2466
137_L2 IT ENTl 49.5652
138_L2 IT ENTl 91.8919
139_L2 IT ENTl 89.2397
13_Lamp5 27.9832
140_L3 IT ENTm 69.0501
141_L3 IT ENTm 71.6088
142_L3 IT ENTm 41.194
143_L3 IT ENTm 85.9289
144_L3 IT ENTl 27.4945
145_L3 IT ENTl 74.4526
146_L3 IT ENTl 55.8967
147_L2 IT PAR 52.9801
148_L2 IT PAR 89.857

In [105]:
a/len(labels)

0.6291459557575642