In [1]:
from tqdm import tqdm
from scipy import io
from scipy import sparse
import scipy
import gzip
import scanpy

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data import DataLoader, Dataset
from torch.distributions.beta import Beta

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import os
import pickle

In [2]:
# 0 A Census of Immune Cells
# ~ 71 min
DIR_PATH = "/data/home/kimds/Data/A Census of Immune Cells"
census_blood = scanpy.read_loom(DIR_PATH + "/1M-immune-human-blood-10XV2.loom")
census_immune = scanpy.read_loom(DIR_PATH + "/1M-immune-human-immune-10XV2.loom")
census_of_immune_cells_genes = pd.read_csv(DIR_PATH+'/genes.csv')

census_blood.obs_names_make_unique()
census_blood.var_names_make_unique()
census_immune.obs_names_make_unique()
census_immune.var_names_make_unique()

census_raw = sparse.vstack([census_blood.X, census_immune.X]).T

OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("var")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("var")


In [9]:
# 3 Heart Cell Atlas
# ~ 1 min
DIR_PATH = '/data/home/kimds/Data/Heart Cell Atlas'
heart_raw = io.mmread(DIR_PATH+'/'+'sparse_mtx.mtx')
heart_genes = pd.read_csv(DIR_PATH+'/genes.csv')

In [4]:
# 4 Immune Cell Atlas
# ~ 16 min
DIR_PATH = '/data/home/kimds/Data/Immune Cell Atlas'
immune_raw = io.mmread(DIR_PATH+'/'+'sparse_mtx.mtx')
immune_genes = pd.read_csv(DIR_PATH+'/genes.csv')

In [5]:
# 5 Immune Cells in Critical COVID19
# ~ 60 min
DIR_PATH = '/data/home/kimds/Data/Immune Cells in Critical COVID19'
FILE_PATH = '/data/home/kimds/Data/Immune Cells in Critical COVID19/GSE158055_covid19_counts.mtx.gz'
covid_raw = io.mmread(FILE_PATH)
covid_features = pd.read_csv(DIR_PATH+'/GSE158055_covid19_features.tsv.gz', sep='\t', compression='gzip', header=None)


In [101]:
census_genes = census_blood.var.index

In [5]:
DIR_PATH = '/data/home/kimds/Data/Heart Cell Atlas'
heart_genes = pd.read_csv(DIR_PATH+'/genes.csv')
heart_genes = heart_genes.iloc[:, 0]

In [6]:
DIR_PATH = '/data/home/kimds/Data/Immune Cell Atlas'
immune_genes = pd.read_csv(DIR_PATH+'/genes.csv')
immune_genes = immune_genes.iloc[:, 0]

In [7]:
DIR_PATH = '/data/home/kimds/Data/Immune Cells in Critical COVID19'
covid_features = pd.read_csv(DIR_PATH+'/GSE158055_covid19_features.tsv.gz', sep='\t', compression='gzip', header=None)
covid_genes = covid_features.iloc[:, 0]

In [102]:
temp = census_genes
temp = temp[temp.isin(heart_genes)]
temp = temp[temp.isin(immune_genes)]
temp = temp[temp.isin(covid_genes)]
census_indices = census_genes.isin(temp)
heart_indices = heart_genes.isin(temp)
immune_indices = immune_genes.isin(temp)
covid_indices = covid_genes.isin(temp)

In [105]:
number_of_genes = len(temp)

In [106]:
census = census_raw.tocsc()[census_indices, :]
heart = heart_raw.tocsc()[heart_indices, :]
immune = immune_raw.tocsc()[immune_indices, :]
covid = covid_raw.tocsc()[covid_indices, :]

In [107]:
data = sparse.hstack([census, heart, immune, covid])

<26326x2665987 sparse matrix of type '<class 'numpy.float64'>'
	with 3483512308 stored elements in Compressed Sparse Column format>

In [108]:
io.mmwrite("/data/home/kimds/Data/data.mtx", data)

: 

: 

In [10]:
class CustomDataset(Dataset):
    def __init__(self, mtx):
        super(CustomDataset, self).__init__()
        self.mtx = mtx.tocsc()
        self.number_of_genes, self.number_of_cells = mtx.shape
    def __len__(self):
        return self.number_of_cells

    def __getitem__(self, idx):
        X = torch.FloatTensor(np.asarray(self.mtx[:, idx].todense()).squeeze())        
        return X

In [11]:
def normalize(mtx, C=1e4):
    mtx = mtx.tocsc()
    new_mtx = mtx.astype(np.float64)
    for j in range(len(mtx.indptr)-1):
        index_0 = mtx.indptr[j]
        index_1 = mtx.indptr[j+1]
        new_mtx.data[index_0:index_1] = np.log(C*mtx.data[index_0:index_1]/np.sum(mtx.data[index_0:index_1]+1))
    return new_mtx

In [12]:
# Temp Data
heart = normalize(heart_raw)
number_of_genes = heart.shape[0]

In [13]:
class BetaVAE(nn.Module):
    def __init__(self, number_of_genes):
        super(BetaVAE, self).__init__()

        self.number_of_genes = number_of_genes

        self.encode = nn.Sequential(
            nn.Linear(self.number_of_genes, 1000),
            nn.ReLU(),
            nn.BatchNorm1d(1000),
            nn.Dropout(),
            
            nn.Linear(1000, 100),
            nn.ReLU(),
            nn.BatchNorm1d(100),
            nn.Dropout(),
        )
        self.linear_a = nn.Linear(100, 10)
        self.linear_b = nn.Linear(100, 10)


        self.decode = nn.Linear(10, self.number_of_genes)
        self.decode.weight.data.fill_(0.5)

      

    def forward(self, x):
        # Encoding
        x = self.encode(x)
        a = torch.exp(self.linear_a(x))
        b = torch.exp(self.linear_b(x))

        # Random sampling (reparametrization trick)
        z = Beta(a, b).sample()

        # Decoding
        x_decoded = self.decode(z)

        
        return x_decoded

In [14]:
def train(model, device, train_dataloader, optim, epoch):
    mse = nn.MSELoss()
    softmax = nn.Softmax(dim=1)
    model.train()
    total_loss = 0
    for i, X in enumerate(train_dataloader):
        X = X.to(device)
        optim.zero_grad()
        pred = model.forward(X)
        loss = mse(pred, X)
        total_loss += loss.item()
        loss.backward()
        optim.step()
        
        model.decode.weight.data.copy_(torch.sigmoid(model.decode.weight.data))
    average_loss = total_loss / len(train_dataloader)
    if epoch % 1 == 0:
        print('epoch', epoch)
        print('TRAIN loss =', average_loss*1000)

In [16]:
dataset = CustomDataset(heart)

train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
device = 'cuda:0'
vae = BetaVAE(number_of_genes).to(device)
learning_rate = 0.000002
optimizer = optim.Adam(vae.parameters(), lr=learning_rate)
epochs = 10
for epoch in range(epochs):
    train(vae, device, train_loader, optimizer, epoch)
torch.cuda.empty_cache()

epoch 0
TRAIN loss = 11275.272805552713
epoch 1
TRAIN loss = 11277.879597062423
epoch 2
TRAIN loss = 11275.199492399308
epoch 3
TRAIN loss = 11255.741848445648
epoch 4
TRAIN loss = 11233.660024842931
epoch 5
TRAIN loss = 11213.731784775784
epoch 6
TRAIN loss = 11208.05742520495
epoch 7
TRAIN loss = 11214.231879125366
epoch 8
TRAIN loss = 11175.556101522907
epoch 9
TRAIN loss = 11156.556118262206


In [None]:
class Discriminator(nn.Module):
    def __init__(self, number_of_genes):
        super(Discriminator, self).__init__()
        self.number_of_genes = number_of_genes
        self.model = nn.Sequential(
            nn.Linear(self.number_of_genes, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(),

            nn.Linear(512, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(),

            nn.Linear(256, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(),

            nn.Linear(128, 64),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Dropout(),

            nn.Linear(64, 32),
            nn.ReLU(),
            nn.BatchNorm1d(32),
            nn.Dropout(),

            nn.Linear(32, 1)
        )
    
    def forward(self, x):
        x = self.model(x)
        z = torch.sigmoid(x)
        return z
        

In [None]:
def train(epoch, device, train_loader, vae, discriminator, gen_opt, disc_opt):
    vae.train()
    discriminator.train()
    total_g = 0
    total_d = 0

    bce = nn.BCELoss()

    for X in train_loader:
        X = X.to(device)

        size = discriminator(X).size()
        ones = torch.ones(size, requires_grad=False).to(device)
        zeros = torch.zeros(size, requires_grad=False).to(device)

        gen_opt.zero_grad()
        X_ = vae(X)
        gen_loss = bce(discriminator(X_), ones)
        gen_loss.backward()
        total_g += gen_loss.item()
        gen_opt.step()

        disc_opt.zero_grad()        
        actual_loss = bce(discriminator(X), ones)
        fake_loss = bce(discriminator(X_.detach()), zeros)
        disc_loss = (actual_loss + fake_loss) / 2
        disc_loss.backward()
        total_d += disc_loss.item()
        disc_opt.step()
    if epoch % 1 == 0:
        print('epoch', epoch, 'G', total_g/len(train_loader), 'D', total_d/len(train_loader))


In [None]:
immune_dataset = CustomDataset(immune)
dataset = immune_dataset
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

device = 'cuda:4'
learning_rate = 0.01
epochs = 100

vae = BetaVAE(number_of_genes)
vae.load_state_dict(torch.load('/data/home/kimds/GAN/betavae.pt'))
discriminator = Discriminator(number_of_genes)
gen_opt = optim.Adam(vae.parameters(), lr=learning_rate)
disc_opt = optim.Adam(discriminator.parameters(), lr=learning_rate)

for epoch in range(epochs):
    train(epoch, device, train_loader, vae, discriminator, gen_opt, disc_opt)
torch.cuda.empty_cache()