In [1]:
import tensorflow as tf
import numpy as np
import pandas as pd
import scanpy as sc
import anndata as ad
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm

import tensorflow_probability as tfp
from tensorflow.keras import backend as K
import pdb

In [25]:
data_directories = ["../DATA/b_cells_filtered_gene_bc_matrices/filtered_matrices_mex/hg19/",
                    "../DATA/cd4_t_helper_filtered_gene_bc_matrices/filtered_matrices_mex/hg19/",
                    "../DATA/cd14_monocytes_filtered_gene_bc_matrices/filtered_matrices_mex/hg19/",
                    "../DATA/cd34_filtered_gene_bc_matrices/filtered_matrices_mex/hg19/",
                    "../DATA/cd56_nk_filtered_gene_bc_matrices/filtered_matrices_mex/hg19/",
                    "../DATA/cytotoxic_t_filtered_gene_bc_matrices/filtered_matrices_mex/hg19/",
                    "../DATA/memory_t_filtered_gene_bc_matrices/filtered_matrices_mex/hg19/",
                    "../DATA/naive_cytotoxic_filtered_gene_bc_matrices/filtered_matrices_mex/hg19/",
                    "../DATA/naive_t_filtered_gene_bc_matrices/filtered_matrices_mex/hg19/",
                    "../DATA/regulatory_t_filtered_gene_bc_matrices/filtered_matrices_mex/hg19/"]
cell_types = ['B_cell','CD4_helper','CD14','CD34','CD56_NK','CD8_cytotoxic','CD4_CD45RO_memory','CD8_CD45RA_naive','CD4_CD45RA_naive','CD4_CD25_regulatory']
bkdata_path = '../DATA/TCGA/TCGA_GDC_HTSeq_Counts.txt'
# gene_list_path = '../DATA/Immune Gene Lists/genes.csv'
data_paths = ['../DATA/TCGA/TCGA_GDC_HTSeq_TPM.csv',
              '../DATA/METABRIC/METABRIC.csv',
              '../DATA/SDY67/SDY67_477.csv',
              '../DATA/Gene Lists/immport_genelist.csv',
              '../DATA/Gene Lists/scdata_genelist_filtered.csv']

In [26]:
def FeatureList(paths: list) -> list:
    features = None
    for path in paths:
        mydata = pd.read_csv(path, index_col = 0)
        if features == None:
            features = set(mydata.index.values.tolist())
        else:
            features = features.intersection(set(mydata.index.values.tolist()))
    features = list(features)
    features.sort()
    return features

class DataPreprocess():
    def __init__(self, datadir, celltypes, bkdata_path, features):
        '''
        Creates preprocessed instance of input data
        scdata should be in matrix.mtx within specified folders along with barcodes.tsv and genes.tsv
        bkdata should have sample names as columns and gene names as rows
        gene_list should have no row or column names/index
        '''
        self.datadir = datadir
        self.celltypes = celltypes
        self.scdata = self.load_scdata(self.datadir, self.celltypes)
        self.bkdata = pd.read_csv(bkdata_path)
        # If there is input gene list, filter out genes not in bkdata or scdata
        if features is None:
            self.features = self.bkdata.index.drop_duplicates()
        else:
            self.features = features
        # Filter out genes not in gene list
        self.scdata = self.scdata[:,self.scdata.var_names.isin(self.features)]
        sc.pp.normalize_total(self.scdata, target_sum=1e6) # normalize to sum to 1,000,000
        # sc.pp.regress_out(scdata, ['total_counts'], n_jobs=1)
        # Transpose, filter out genes not in gene list, then sort column (by gene name)
        self.bkdata = self.bkdata.T
        self.bkdata = self.bkdata.loc[:,self.bkdata.columns.isin(self.features)].sort_index(axis=1)
        self.bkdata = self.bkdata.values.astype(float)
    def load_scdata(self, data_directories, cell_types):
        # Read and merge 10X Genomics scRNA-seq data
        scdata = None
        print('Loading single cell dataset')
        for d, c in zip(tqdm(data_directories), cell_types):
            x = sc.read_10x_mtx(d)
            x.obs['celltype'] = [c]*len(x.obs.index)
            # Change each observation (cell) name to celltype + barcode
            x.obs.set_index(pd.Index([c+'_'+rn[:-2] for rn in x.obs.index]), inplace=True)
            if scdata is not None:
                scdata = ad.concat([scdata, x])
            else:
                scdata = x
        # Filter out cells and genes
        sc.pp.filter_cells(scdata, min_genes=200)
        sc.pp.filter_genes(scdata, min_cells=1)
        # Search for prefix "MT-" (mitochondrial genes) and make new column in variable annotations
        # Search for prefix "RPL/RPS" for ribosomal genes and "MRPL/MRPS" for mitochondrial ribosomal genes
        scdata.var['mito'] = scdata.var.index.str.match('^MT-')
        scdata.var['ribo'] = scdata.var.index.str.startswith(('RPL','RPS'))
        scdata.var['mribo'] = scdata.var.index.str.startswith(('MRPL','MRPS'))
        # Calculate QC metrics as per McCarthy et al., 2017 (Scater)
        sc.pp.calculate_qc_metrics(scdata, qc_vars=['mito','ribo', 'mribo'], inplace=True)
        # Plot QC metrics
        # sns.jointplot(x='total_counts', y='n_genes_by_counts', height=8, data=scdata.obs,
        #     kind='scatter', hue='celltype')
        # sns.jointplot(x='total_counts', y='pct_counts_mito', height=8, data=scdata.obs,
        #     kind='scatter', hue='celltype')
        # sns.jointplot(x='total_counts', y='pct_counts_ribo', height=8, data=scdata.obs,
        #     kind='scatter', hue='celltype')
        # sns.jointplot(x='total_counts', y='pct_counts_mribo', height=8, data=scdata.obs,
        #     kind='scatter', hue='celltype')
        # plt.show()
        # Filter out cells with >5% of counts from mitochondria and mitoribosome
        # scdata = scdata[scdata.obs.pct_counts_ribo > 30, :]
        scdata = scdata[scdata.obs.pct_counts_mito < 5, :]
        scdata = scdata[scdata.obs.pct_counts_mribo < 1, :]
        return scdata
    def __call__(self, whichdata, batch_size=1):
        if whichdata == 'scdata':
            out = []
            print('Dividing single cell dataset into cell types')
            for c in tqdm(self.celltypes):
                scdata_ = self.scdata[self.scdata.obs.celltype==c].to_df().sort_index(axis=1)
                # Add to row index 0 a cell with no gene expression (all zeros)
                # zeros = pd.DataFrame(np.zeros((1,scdata_.shape[1])), columns=scdata_.columns.values)
                # Expand into batch dimension and repeat 2-D tensor by # of samples per mini batch
                # scdata_ = tf.tile(tf.expand_dims(pd.concat([zeros,scdata_]), axis=0), [batch_size,1,1])
                out.append(scdata_)
        elif whichdata == 'bkdata':
            out = self.bkdata
        elif whichdata == 'genelist':
            out = self.features
        else:
            raise ValueError('Choose only one of the following: "scdata", "bkdata", or "genelist"')
        return out
    
class Subsampling(tf.keras.layers.Layer):

    def __init__(self, scdata):
        super(Subsampling, self).__init__()
        # initialize one layer for each cell type
        self.scdata=scdata

    def call(self, inputs):
        # select {number of cells} of random column indices from scdata with uniform probability
        # allows for sampling with replacement (increases variability)
        idx = tf.random.uniform(
            shape=[inputs],
            minval=0, maxval=self.scdata.shape[0]-1,
            dtype=tf.int32)
        # subset scdata with selected random column indices
        subset = tf.gather(self.scdata, idx, axis=1)
        return tf.reduce_sum(subset, axis=1)

class AdversarialSimulator():

    '''
    Ref: github.com/eriklindernoren/Keras-GAN
    '''

    def __init__(self, scdata, n_sim_samples = 1000):
        self.scdata = scdata
        self.n_sim_samples = n_sim_samples
        self.n_celltypes = len(scdata)
        self.n_features = scdata[0].shape[1]
        optmzr = tf.keras.optimizers.Adam(0.0002, 0.5)
        self.Discriminator = self.build_discriminator()
        self.Discriminator.compile(loss='binary_crossentropy',
            optimizer=optmzr,
            metrics=['accuracy'])
        self.Simulator = self.build_simulator()
        # Simulator takes in Nprop as input and generates simbulk
        z = Input(shape=(self.n_celltypes))
        img = self.Simulator(z)
        # For the combined model, we only train the Simulator
        self.Discriminator.trainable = False
        # Discriminator takes simbulk as input and determines validity
        valid = self.Discriminator(self.simbulk)
        # stacked Simulator and Discriminator
        self.AdvSim = tf.keras.Model(z, valid)
        self.AdvSim.compile(loss='binary_crossentropy', optimizer=optmzr)

    def MinMaxNorm(self, x):
        x_scaled = tf.math.divide_no_nan(
            (x - tf.math.reduce_min(x)),
            (tf.math.reduce_max(x) - tf.math.reduce_min(x)))
        return x_scaled

    def simulated_fractions(self, batch_size):
        alpha = [1]*self.n_celltypes
        dist = tfp.distributions.Dirichlet(alpha)
        nprop = dist.sample([batch_size])
        return nprop

    def build_simulator(self):
        inputs = tf.keras.layers.Input(shape=self.n_celltypes)
        x = []
        for c in range(self.n_celltypes):
            x.append(Subsampling(self.scdata[c])(inputs))
        x = tf.keras.layers.Add()(x)
        x = tf.keras.layers.Flatten()(x)
        x = tf.keras.layers.Lambda(lambda x: tf.math.log1p)(x)
        x = tf.keras.layers.LayerNormalization(center=True, scale=True)(x)
        outputs = tf.keras.layers.Lambda(lambda x: self.MinMaxNorm)(x)
        model = tf.keras.Model(inputs, outputs)
        model.summary()
        return model

    def build_discriminator(self):
        inputs = tf.keras.Input(shape=(self.n_features,))
        x = tf.keras.layers.Dense(256)(inputs)
        x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)
        x = tf.keras.layers.Dense(128)(x)
        x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)
        outputs = tf.keras.layers.Dense(1, activation='sigmoid')(x)
        model = tf.keras.Model(inputs, outputs)
        model._name = "Discriminator"
        model.summary()
        return model

    def train(self, X_data, steps=1000):
        X_data = self.MinMaxNorm(tf.math.log1p(X_data))
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))
        for step in range(steps):
            '''
            Train Discriminator
            '''
            # Select random subset of X_data equal to batch_size
            idx = np.random.randint(0, X_data.shape[0], 1)
            bulk = X_data.numpy()[idx]
            # Sample Nprop (cell fractions) using Dirichlet distribution
            nprop = self.simulated_fractions(batch_size)
            # Generate simbulk using Nprop
            simbulk = self.Simulator.predict(nprop)
            # Train Discriminator
            d_loss_real = self.Discriminator.train_on_batch(bulk, valid)
            d_loss_fake = self.Discriminator.train_on_batch(simbulk, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
            '''
            Train Simulator
            '''
            # Train Simulator (wants Discriminator to make mistakes)
            s_loss = self.AdvSim.train_on_batch(nprop, valid)
            # Plot the progress
            print ("Step: %d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (step, d_loss[0], 100*d_loss[1], s_loss))


In [20]:
myData = DataPreprocess(data_directories, cell_types, bkdata_path, features)

Loading single cell dataset


100%|██████████| 10/10 [01:58<00:00, 11.73s/it]
  view_to_actual(adata)


In [27]:
myModel = AdversarialSimulator(myData('scdata'))

Dividing single cell dataset into cell types


100%|██████████| 10/10 [00:01<00:00,  8.19it/s]


Model: "Discriminator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 3010)]            0         
_________________________________________________________________
dense (Dense)                (None, 256)               770816    
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 256)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 128)               32896     
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 128)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 129       
Total params: 803,841
Trainable params: 803,841
Non-trainable params: 0
_______________________________________________

AttributeError: 'AdversarialSimulator' object has no attribute 'n_celltype'

In [None]:
myModel.train(myData('bkdata'))

In [None]:
tcga = pd.read_csv(bkdata_path, sep='\t')
tcga = tcga.sort_index()
tcga.to_csv('../DATA/TCGA/TCGA_GDC_HTSeq_TPM.csv')

In [None]:
abis = pd.read_csv('../DATA/GSE107011/GSE107011_Processed_data_TPM.txt')
epic = pd.read_csv('../DATA/EPIC/melanoma_counts.csv')

In [None]:
metabric = pd.read_csv('../DATA/METABRIC/data_expression_median.txt', sep='\t', index_col=0)
metabric.index.name = None
metabric = metabric.iloc[:,1:]
metabric = metabric.sort_index()
metabric.to_csv('../DATA/METABRIC/METABRIC.csv')

In [None]:
sdy67_1 = pd.read_csv('../DATA/SDY67/SDY67_EXP13377_RNA_seq.703318.tsv', sep='\t', index_col=0)
sdy67_2 = pd.read_csv('../DATA/SDY67/SDY67_EXP14625_RNA_seq.703317.tsv', sep='\t', index_col=0)
sdy67_1.index.name = None
sdy67_2.index.name = None
sdy67 = sdy67_1.join(sdy67_2)
sdy67 = sdy67.sort_index()
sdy67_meta = pd.read_csv('../DATA/SDY67/SDY67-DR34_Subject_2_RNA_sequencing_result.txt', sep='\t')
sdy67_meta['SubjectID'] = [sub+'_'+str(time) for sub, time in zip(list(sdy67_meta['Subject Accession'].values),list(sdy67_meta['Study Time Collected'].values))]
sdy67_id = sdy67_meta.reset_index().set_index('Expsample Accession').loc[sdy67.columns.values.tolist(),'SubjectID'].values.tolist()
sdy67.columns = sdy67_id
sdy67.to_csv('../DATA/SDY67/SDY67_477.csv')

sdy67_label = pd.read_csv('../DATA/SDY67/SDY67_extracted_from_mmc7.csv', index_col=0)
sdy67_label.index.name = None
sdy67_label['Other'] = 0
sdy67_label = sdy67_label.fillna(0)
sdy67_label = sdy67_label.loc[sdy67.columns,]
for i in range(len(sdy67_label)):
    sumval = sum(sdy67_label.iloc[i,:])
    if sumval >= 100:
        sdy67_label.iloc[i,:] = sdy67_label.iloc[i,:]/sumval
    else:
        sdy67_label.iloc[i,5] = 100-sumval
        sdy67_label.iloc[i,:] = sdy67_label.iloc[i,:]/100
sdy67_label.to_csv('../DATA/SDY67/SDY67_477_label.csv')

In [None]:
# ABIS dataset need to load genes.tsv then match Ensembl id to gene id 
# EPIC dataset needs fixing
abis = pd.read_csv('../DATA/GSE107011/GSE107011_Processed_data_TPM.txt', index_col=0) # EnsDb.Hsapiens.v79, aggregated signals from duplicate probes by max value

In [None]:
abis[abis.columns[abis.columns.to_series().str.contains('_PBMC')]]

In [None]:
features = set(tcga.index.values.tolist()).intersection(set(metabric.index.values.tolist()))

In [None]:
features = list(features)

In [None]:
features

In [None]:
tcga.T.loc[:,features].head()

In [None]:
def FeatureList(paths: list) -> list:
    features = None
    for path in paths:
        mydata = pd.read_csv(path, header=None)
        if features == None:
            features = set(mydata.index.values.tolist())
        else:
            features = features.intersection(set(mydata.index.values.tolist()))
    features = list(features)
    features.sort()
    return features

In [None]:
class BulkDataset(torch.utils.data.Dataset):
    
    def __init__(self, csv_path, features):
        self.bkdata = pd.read_csv(csv_path, index_col=0)
        self.bkdata = self.bkdata.T.loc[:,features] # when not sorted, add ".sort_index(axis=1)"
    
    def __len__(self):
        return len(self.bkdata)
    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        return self.bkdata[idx,:]

In [None]:
def load_scdata(data_directories, cell_types):
    # Read and merge 10X Genomics scRNA-seq data
    scdata = None
    for d, c in zip(tqdm(data_directories), cell_types):
        x = sc.read_10x_mtx(d)
        x.obs['celltype'] = [c]*len(x.obs.index)
        # Change each observation (cell) name to celltype + barcode
        x.obs.set_index(pd.Index([c+'_'+rn[:-2] for rn in x.obs.index]), inplace=True)
        if scdata is not None:
            scdata = ad.concat([scdata, x])
        else:
            scdata = x
    # Filter out cells and genes
    sc.pp.filter_cells(scdata, min_genes=200)
    sc.pp.filter_genes(scdata, min_cells=1)
    # Search for prefix "MT-" (mitochondrial genes) and make new column in variable annotations
    # Search for prefix "RPL/RPS" for ribosomal genes and "MRPL/MRPS" for mitochondrial ribosomal genes
    scdata.var['mito'] = scdata.var.index.str.match('^MT-')
    scdata.var['ribo'] = scdata.var.index.str.startswith(('RPL','RPS'))
    scdata.var['mribo'] = scdata.var.index.str.startswith(('MRPL','MRPS'))
    # Calculate QC metrics as per McCarthy et al., 2017 (Scater)
    sc.pp.calculate_qc_metrics(scdata, qc_vars=['mito','ribo', 'mribo'], inplace=True)
    # Plot QC metrics
    # sns.jointplot(x='total_counts', y='n_genes_by_counts', height=8, data=scdata.obs,
    #     kind='scatter', hue='celltype')
    # sns.jointplot(x='total_counts', y='pct_counts_mito', height=8, data=scdata.obs,
    #     kind='scatter', hue='celltype')
    # sns.jointplot(x='total_counts', y='pct_counts_ribo', height=8, data=scdata.obs,
    #     kind='scatter', hue='celltype')
    # sns.jointplot(x='total_counts', y='pct_counts_mribo', height=8, data=scdata.obs,
    #     kind='scatter', hue='celltype')
    # plt.show()
    # Filter out cells with >5% of counts from mitochondria and mitoribosome
    # scdata = scdata[scdata.obs.pct_counts_ribo > 30, :]
    scdata = scdata[scdata.obs.pct_counts_mito < 5, :]
    scdata = scdata[scdata.obs.pct_counts_mribo < 1, :]
    return scdata

In [None]:
scdata = load_scdata(data_directories, cell_types)

In [None]:
len(features)

In [None]:
torch.Tensor(scdata[scdata.obs.celltype=='B_cell',scdata.var_names.isin(features)].to_df().sort_index(axis=1).to_numpy()).shape

In [None]:
scdata_list = []
for c in tqdm(cell_types):
    scdata_list.append(torch.Tensor(scdata[scdata.obs.celltype==c, scdata.var_names.isin(features)].to_df().sort_index(axis=1).to_numpy()))

In [None]:
tmp = torch.nn.Linear(5,10)
tmp = tmp(torch.Tensor([1.,2.,3.,4.,5.]))
dist = torch.distributions.multinomial.Multinomial(total_count=500, logits=tmp)
tmp = dist.sample()

In [None]:
tmp

In [None]:
for i in tmp:
    print(torch.multinomial(torch.Tensor([1]*700), int(i), replacement=False))

In [None]:
torch.distributions.multinomial.Multinomial(total_count=500, probs=torch.Tensor([[1,4,5,4,7,4],[5,3,4,5,6,2]])).sample()

In [None]:
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
import pandas as pd
import seaborn as sns
import os
import tensorflow.keras.backend as K
from tqdm import tqdm
import matplotlib.pyplot as plt
import anndata as ad
import scanpy as sc
from model import *
from data import *

In [None]:
data_directories = ["../DATA/b_cells_filtered_gene_bc_matrices/filtered_matrices_mex/hg19/",
                    "../DATA/cd4_t_helper_filtered_gene_bc_matrices/filtered_matrices_mex/hg19/",
                    "../DATA/cd14_monocytes_filtered_gene_bc_matrices/filtered_matrices_mex/hg19/",
                    "../DATA/cd34_filtered_gene_bc_matrices/filtered_matrices_mex/hg19/",
                    "../DATA/cd56_nk_filtered_gene_bc_matrices/filtered_matrices_mex/hg19/",
                    "../DATA/cytotoxic_t_filtered_gene_bc_matrices/filtered_matrices_mex/hg19/",
                    "../DATA/memory_t_filtered_gene_bc_matrices/filtered_matrices_mex/hg19/",
                    "../DATA/naive_cytotoxic_filtered_gene_bc_matrices/filtered_matrices_mex/hg19/",
                    "../DATA/naive_t_filtered_gene_bc_matrices/filtered_matrices_mex/hg19/",
                    "../DATA/regulatory_t_filtered_gene_bc_matrices/filtered_matrices_mex/hg19/"]
cell_types = ['B_cell','CD4_helper','CD14','CD34','CD56_NK','CD8_cytotoxic','CD4_CD45RO_memory','CD8_CD45RA_naive','CD4_CD45RA_naive','CD4_CD25_regulatory']
bkdata_path = '../DATA/TCGA/TCGA_GDC_HTSeq_Counts.txt'
gene_list_path = '../DATA/Immune Gene Lists/genes.csv'

In [None]:
myData = DataPreprocess(data_directories, cell_types, bkdata_path, gene_list_path)

In [None]:
X = np.expand_dims(np.array(range(100)), axis=1).repeat(100, axis=1)

In [None]:
dns = tf.keras.layers.Dense(10, activation='relu', input_shape=(100,))
x = dns(X)

In [None]:
x[1:3,:]

In [None]:
dml = tfp.layers.DistributionLambda(
    make_distribution_fn=lambda t: tfp.distributions.DirichletMultinomial(
        total_count=500, concentration=t),
    convert_to_tensor_fn=lambda s: s.sample(1)
)

In [None]:
y = dml(x[1,:])

In [None]:
y.sample(1)

In [None]:
scdata = myData('scdata')

In [None]:
tf.reshape(scdata[0], [9899,-1])

In [None]:
scdata0 = tf.data.Dataset.from_tensors(tf.reshape(scdata[0], [9899,-1]))

In [None]:
list(scdata0.shuffle(10000).take(500).as_numpy_iterator())[0].shape

In [None]:
tf.reshape(scdata[0], [9899,-1]).numpy()[list(tf.data.Dataset.range(9899).shuffle(10000).take(500).as_numpy_iterator()),:].shape

In [None]:

idxs = tf.range(tf.shape(inputs)[0])
ridxs = tf.random.shuffle(idxs)[:sample_num]
rinput = tf.gather(inputs, ridxs)

In [None]:
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(10, activation='relu', input_shape=(100,)))
model.add(tfp.layers.DistributionLambda(
    make_distribution_fn=lambda a: tfp.distributions.DirichletMultinomial(
        total_count=500, concentration=a)))