In [44]:
!git clone https://github.com/sbc806/cpsc-522-project.git

fatal: destination path 'cpsc-522-project' already exists and is not an empty directory.


In [45]:
import sys as sys

sys.path.append("cpsc-522-project")

In [46]:
!pip install scanpy



In [47]:
!pip3 install bbknn



In [48]:
import numpy as np
import os as os
import scanpy as sc
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from collections import defaultdict

from hyperspherical_vae.distributions import VonMisesFisher
from hyperspherical_vae.distributions import HypersphericalUniform

In [49]:
sc.settings.verbosity = 3  # verbosity: errors (0), warnings (1), info (2), hints (3)
sc.logging.print_header()
sc.settings.set_figure_params(dpi=80, facecolor='white')

scanpy==1.9.3 anndata==0.9.1 umap==0.5.3 numpy==1.23.5 scipy==1.10.0 pandas==1.5.3 scikit-learn==1.2.0 statsmodels==0.13.5 pynndescent==0.5.8


In [50]:
def read_data(filename):
    adata = sc.read_h5ad(filename)
    return adata

In [51]:
file_path = "./data"
# file_path = "./drive/MyDrive/immune_cell_dataset/"
# data can be myeloid, b_cells, or t_cells
data = "myeloid"
# file_path = os.path.join(file_path, data)
if data == "myeloid":
    filename = "myeloid.h5ad"
elif data == "b_cells":
    filename = "b-cells.h5ad"
else:
    filename = "t-cells.h5ad"
file_path = os.path.join(file_path, filename)
# file_path = os.path.join(file_path, filename)

adata = read_data(file_path)

Data preprocessing

In [52]:
# Data already normalized and log-transformed
# Select highly variable genes
sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0.5)
adata = adata[:, adata.var.highly_variable]
print(adata.X.shape)
# Scale to unit variance and zero mean
sc.pp.scale(adata)

extracting highly variable genes
    finished (0:00:02)
--> added
    'highly_variable', boolean vector (adata.var)
    'means', float vector (adata.var)
    'dispersions', float vector (adata.var)
    'dispersions_norm', float vector (adata.var)
(51552, 1816)


  view_to_actual(adata)


... as `zero_center=True`, sparse input is densified and may lead to large memory consumption


Get rid of batch effects

In [53]:
import bbknn

bbknn.ridge_regression(adata, batch_key=['Chemistry'])

computing ridge regression
	finished: `.X` now features regression residuals
	`.layers['X_explained']` stores the expression explained by the technical effect (0:00:00)


Model

In [54]:

# noinspection PyUnresolvedReferences
# noinspection PyCallingNonCallable
class ProductSpaceVAE(torch.nn.Module):

    # def __init__(self, h_dims, z_dims, input_size=[1, 28, 28], input_type = 'binary', distribution='normal',
    # r=None, encode_type='mlp', decode_type='mlp', device='cpu', flags=None):
    def __init__(self, z_dims, n_gene, distribution='normal',
                 r=None, encoder_layer=None, decoder_layer=None, activation=F.relu,
                 device='cpu', flags=None):
        """
        ModelVAE initializer
        # :param in_dim: dimension of input
        :param n_gene: dimension of input
        # :param h_dims: dimension of the hidden layers, list
        :param z_dims: dimensions of the latent representation, list
        :param distribution: string either `normal` or `vmf`, indicates which distribution to use
        :param r: radii scalars, list
        :encoder_layer: a list with the units of each layer for the encoder
        :decoder_layer: a list with units of each layer for the decoder
        :param device: device to use
        """
        super(ProductSpaceVAE, self).__init__()

        self.flags = flags
        self.name = 'productspace'
        self.epochs, self.num_restarts = 0, 0
        # self.input_size, self.distribution, self.device = input_size, distribution, device
        self.n_gene, self.distribution, self.device = n_gene, distribution, device
        # self.encode_type, self.decode_type = encode_type, decode_type
        self.activation = activation

        self.z_dims = np.sort(np.asarray(z_dims))
        self.z_unique, self.z_counts = np.unique(self.z_dims, return_counts=True)
        self.z_u_idx = [np.where(self.z_dims == u)[0] for u in self.z_unique]

        self.r = torch.ones(len(z_dims), device=device) if r is None else r  # not used yet

        # self.encoder, self.fc_means, self.fc_vars = create_encoder(input_size, input_type, self.z_dims, h_dims,
        # distribution, encode_type, flags)
        if encoder_layer is None:
            encoder_layer = [128, 64, 32]
        if decoder_layer is None:
            decoder_layer = [32, 128]

        # Output of the encoder
        h_last = encoder_layer[-1]
        self.fc_means = nn.ModuleList([nn.Linear(h_last, z_dim) for z_dim in z_dims])
        self.fc_vars = nn.ModuleList([nn.Sequential(nn.Linear(h_last, (1 if distribution == 'vmf' else z_dim)),
                                                    nn.Softplus(), nn.Hardtanh(min_val=0.01, max_val=7.))
                                      for z_dim in z_dims])
        # Output of decoder
        self.mu_layer = nn.Linear(decoder_layer[-1], n_gene)
        self.var_layer = nn.Linear(decoder_layer[-1], n_gene)

        # self.decoder = create_decoder(input_size, input_type, z_dims, h_dims, decode_type)

        # Create layers for encoder
        self.encoder_layer = [nn.Linear(n_gene, encoder_layer[0])]
        # BatchNorm
        # self.encoder_batch_norm = [nn.BatchNorm1d(encoder_layer[0])]
        for i in range(1, len(encoder_layer)):
            self.encoder_layer.append(nn.Linear(encoder_layer[i - 1], encoder_layer[i]))
            # BatchNorm
            # self.encoder_batch_norm.append(nn.BatchNorm1d(encoder_layer[i]))
        self.encoder_layer = nn.ModuleList(self.encoder_layer)
        # BatchNorm
        # self.encoder_batch_norm = nn.ModuleList(self.encoder_batch_norm)

        # self.decoder_layer = [nn.Linear(z_dim, decoder_layer[0])]
        self.decoder_layer = [nn.Linear(sum(z_dims), decoder_layer[0])]
        # BatchNorm
        # self.decoder_batch_norm = [nn.BatchNorm1d(decoder_layer[0])]
        for i in range(1, len(decoder_layer)):
            self.decoder_layer.append(nn.Linear(decoder_layer[i - 1], decoder_layer[i]))
            # BatchNorm
            # self.decoder_batch_norm.append(nn.BatchNorm1d(decoder_layer[i]))
        self.decoder_layer = nn.ModuleList(self.decoder_layer)
        # BatchNorm
        # self.decoder_batch_norm = nn.ModuleList(self.decoder_batch_norm)

    @staticmethod
    def _print(x):
        print('\n')
        print(x)
        print('\n')

    def encode(self, x):
        # x = self.encoder_mlp(x)

        # if self.encode_type == 'cnn':
        # x = x.reshape(x.size(0), *self.input_size)

        # h = self.encoder(x)
        # h = h.view(h.size(0), -1)

        # regularizer = torch.

        h = self.encoder_layer[0](x)
        h = self.activation(h)
        # Add in batch normalization here
        # h = self.encoder_batch_norm[0](h)
        for i in range(1, len(self.encoder_layer)):
            h = self.encoder_layer[i](h)
            h = self.activation(h)
            # Add in batch normalization here
            # h = self.encoder_batch_norm[i](h)

        if self.distribution == 'normal':
            # compute means and stds of the normal distributions
            z_means = [f(h) for f in self.fc_means]
            z_vars = [f(h) for f in self.fc_vars]
        elif self.distribution == 'vmf':
            # compute means and concentrations of the von Mises-Fishers
            z_means_unnormalized = [f(h) for f in self.fc_means]
            z_means = [zmu / zmu.norm(dim=-1, keepdim=True) for zmu in z_means_unnormalized]
            z_vars = [(f(h) + 1.) for f in self.fc_vars]  # the `+ 1` prevents collapsing behaviors
        else:
            raise NotImplemented

        return z_means, z_vars

    def decode(self, z):
        # if self.decode_type == 'cnn':
        # z = z.view(z.size(0), sum(self.z_dims), 1, 1)

        # x_recon = self.decoder(z)

        # return x_recon.view(x_recon.size(0), -1)
        # l2 regularization goes here
        h = self.decoder_layer[0](z)
        h = self.activation(h)
        # Add in batch normalization here
        # h = self.decoder_batch_norm[0](h)
        for i in range(1, len(self.decoder_layer)):
            h = self.decoder_layer[i](h)
            h = self.activation(h)
            # Add in batch normalization here
            # h = self.decoder_batch_norm[i](h)

        mu = self.mu_layer(h)
        sigma_square = F.softplus(self.var_layer(h))
        sigma_sqare = torch.clip(sigma_square, 1e-6, 1e10)
        return mu, sigma_square

    def reparameterize(self, z_means, z_vars):

        # since z is sorted, we take the min index, and the max index, to slice the list of z_means, z_vars
        # this is done to not have convert to numpy array
        gather_zvs = [(torch.cat(z_means[min(u_idx):max(u_idx) + 1], 0),
                       torch.cat(z_vars[min(u_idx):max(u_idx) + 1], 0))
                      for u_idx in self.z_u_idx]

        if self.distribution == 'normal':
            # for each pair of z_mean, z_var, we make a distribution (sampling) object
            q_zs_sample = [torch.distributions.normal.Normal(z_mean, z_var) for (z_mean, z_var) in gather_zvs]

            q_zs = [torch.distributions.normal.Normal(z_mean, z_var) for z_mean, z_var in zip(z_means, z_vars)]
            p_zs = [torch.distributions.normal.Normal(torch.zeros_like(z_mean), torch.ones_like(z_var)) for
                    z_mean, z_var in zip(z_means, z_vars)]

        elif self.distribution == 'vmf':
            # for each pair of z_mean, z_var, we make a distribution (sampling) object
            q_zs_sample = [VonMisesFisher(z_mean, z_var) for (z_mean, z_var) in gather_zvs]

            q_zs = [VonMisesFisher(z_mean, z_var) for z_mean, z_var in zip(z_means, z_vars)]
            p_zs = [HypersphericalUniform(z_dim - 1, device=self.device) for z_dim in self.z_dims]
        else:
            raise NotImplemented

        return q_zs, p_zs, q_zs_sample

    # def loss(self, q_zs, p_zs, x_mb, x_mb_recon):
    def kl_divergence(self, q_zs, p_zs):
        # if self.flags.loss_function == 'bce':
        # lf = nn.BCEWithLogitsLoss(reduction='none')
        # elif self.flags.loss_function == 'mse':
        # lf = nn.MSELoss(reduction='none')
        # else:
        # raise NotImplemented
        # loss_recon = lf(x_mb_recon, x_mb.reshape(x_mb.size(0), -1)).sum(-1).mean()

        if self.distribution == 'normal':
            loss_kl = torch.stack([torch.distributions.kl.kl_divergence(q_z, p_z).sum(-1) for
                                   q_z, p_z in zip(q_zs, p_zs)], dim=-1).sum(-1).mean()
        elif self.distribution == 'vmf':
            loss_kl = torch.stack([torch.distributions.kl.kl_divergence(q_z, p_z) for
                                   q_z, p_z in zip(q_zs, p_zs)], dim=-1).sum(-1).mean()
        else:
            raise NotImplemented

        # return loss_recon, loss_kl, None
        return loss_kl

    # def log_likelihood(self, x, n=10):
    # """
    # :param x: e.g. MNIST data flattened
    # :param n: number of MC samples
    # :return: MC estimate of log-likelihood
    # """

    # z_means, z_vars = self.encode(x.reshape(x.size(0), -1))
    # q_zs, p_zs, _, = self.reparameterize(z_means, z_vars)
    # z_parts = [q_z.rsample(torch.Size([n])) for q_z in q_zs]
    # z = torch.cat(z_parts, dim=-1).reshape(n*x.size(0), -1)

    # x_mb_recon = self.decode(z)
    # x_mb_recon = x_mb_recon.reshape(n, x.size(0), -1)

    # if self.flags.loss_function == 'bce':
    # lf = nn.BCEWithLogitsLoss(reduction='none')
    # elif self.flags.loss_function == 'mse':
    # lf = nn.MSELoss(reduction='none')
    # else:
    # raise NotImplemented
    # log_p_x_z = -lf(x_mb_recon, x.reshape(x.size(0), -1).repeat((n, 1, 1))).sum(-1)

    # if self.distribution == 'normal':
    # log_p_z = torch.stack([p_z.log_prob(z__).sum(-1) for p_z, z__ in zip(p_zs, z_parts)], dim=-1).sum(-1)
    # log_q_z_x = torch.stack([q_z.log_prob(z__).sum(-1) for q_z, z__ in zip(q_zs, z_parts)], dim=-1).sum(-1)
    # elif self.distribution == 'vmf':
    # log_p_z = torch.stack([p_z.log_prob(z__) for p_z, z__ in zip(p_zs, z_parts)], dim=-1).sum(-1)
    # log_q_z_x = torch.stack([q_z.log_prob(z__) for q_z, z__ in zip(q_zs, z_parts)], dim=-1).sum(-1)
    # else:
    # raise NotImplementedError

    # return ((log_p_x_z + log_p_z.to(self.device) - log_q_z_x).t().logsumexp(-1) - np.log(n)).mean()

    def forward(self, x, n=None):

        z_means, z_vars = self.encode(x)
        if torch.isnan(z_means[0]).sum() > 0 or torch.isnan(z_vars[0]).sum() > 0:
            return (None, None), None, None

        q_zs, p_zs, q_zs_sample = self.reparameterize(z_means, z_vars)

        # sample z1, z2, .., zk and concatenate
        # z = torch.cat([q_z.rsample(torch.Size() if n is None else torch.Size([n])) for q_z in q_zs], dim=-1)  # slow
        z = torch.cat([torch.cat(torch.chunk(q_z.rsample(), int(c), dim=0), dim=-1)
                       for q_z, c in zip(q_zs_sample, self.z_counts)], dim=-1)
        # z_parts = list(torch.split(z, tuple(self.z_unique.repeat(self.z_counts)), -1))
        mu, sigma_square = self.decode(z)

        return (q_zs, p_zs), z, mu, sigma_square

In [55]:
def log_likelihood(model, x, n=10):
    z_mean, z_var = model.encode(x)
    q_zs, p_zs, q_zs_sample = model.reparameterize(z_mean, z_var)
    z = q_z.rsample(torch.Size([n]))
    # Need to add a line here like z = torch.cat() in the forward_pass(self, x, n=None)
    mu_, sigma_square_ = model.decode(z)

    # In scPhere used tf.reduce_mean()
    return torch.mean(log_likelihood_student(x, mu_, sigma_square_, df=2.0))

In [56]:
def log_likelihood_student(x, mu, sigma_square, df=2.0):
    sigma = torch.sqrt(sigma_square)

    dist = torch.distributions.studentT.StudentT(df=df,
                                                 loc=mu,
                                                 scale=sigma)

    # return tf.reduce_sum(dist.log_prob(x), reduction_indices=1)
    return torch.mean(dist.log_prob(x), dim=1)

In [57]:
def train(model, optimizer, device='cuda'):
    # for i, (x_mb, y_mb) in enumerate(train_loader):
    for i in range(0, X.shape[0], 128):
        # for x in X:
        if i + 128 < X.shape[0]:
            x = X[i: i + 128, :].to(device)
        else:
            x = X[i:, :].to(device)
        optimizer.zero_grad()

        # dynamic binarization
        # x_mb = (x_mb > torch.distributions.Uniform(0, 1).sample(x_mb.shape)).float()

        # _, (q_z, p_z), _, x_mb_ = model(x_mb.reshape(-1, 784))
        (q_zs, p_zs), z, mu, sigma_square = model(x)

        # loss_recon = nn.BCEWithLogitsLoss(reduction='none')(x_reconstructed, x).sum(-1).mean()
        # loss_recon = log_likelihood(x, z_mean, z_var)
        # library_size = torch.sum(x, dim=1, keepdim=True)
        # mu = mu * library_size
        # loss_recon = log_likelihood_nb(
        # x,
        # mu,
        # sigma_square,
        # ).sum(-1).mean()
        # loss_recon = torch.mean(log_likelihood_nb(x, mu, sigma_square, eps=1e-10))
        loss_recon = torch.mean(log_likelihood_student(x, mu, sigma_square, df=5.0))

        if model.distribution == 'normal':
            loss_KL = torch.distributions.kl.kl_divergence(q_z, p_z).sum(-1).mean()
        elif model.distribution == 'vmf':
            # loss_KL = torch.distributions.kl.kl_divergence(q_z, p_z).mean()
            loss_KL = model.kl_divergence(q_zs, p_zs)
        else:
            raise NotImplemented
        loss = loss_recon - loss_KL
        # loss = loss_recon + loss_KL

        loss.backward()
        optimizer.step()

In [58]:
def test(model, optimizer, device='cuda'):
    print_ = defaultdict(list)
    # for x_mb, y_mb in test_loader:
    for i in range(0, X.shape[0], 128):
        if i + 128 < X.shape[0]:
            x = X[i:i + 128, :].to(device)
        else:
            x = X[i:, :].to(device)

        # dynamic binarization
        # x_mb = (x_mb > torch.distributions.Uniform(0, 1).sample(x_mb.shape)).float()

        # _, (q_z, p_z), _, x_mb_ = model(x_mb.reshape(-1, 784))
        (q_zs, p_zs), z, mu_, sigma_square_ = model(x)

        # print_['recon loss'].append(float(nn.BCEWithLogitsLoss(reduction='none')(x_mb_,
        # x_mb.reshape(-1, 784)).sum(-1).mean().data))
        # library_size = torch.sum(x, dim=1, keepdim=True)
        # mu_ = mu_* library_size
        # print_['recon loss'].append(float(torch.mean(log_likelihood_nb(x, mu_, sigma_square_)).data))
        print_['recon loss'].append(float(torch.mean(log_likelihood_student(x, mu_, sigma_square_, df=5.0)).data))

        if model.distribution == 'normal':
            print_['KL'].append(float(torch.distributions.kl.kl_divergence(q_z, p_z).sum(-1).mean().data))
        elif model.distribution == 'vmf':
            # print_['KL'].append(float(torch.distributions.kl.kl_divergence(q_z, p_z).mean().data))
            print_['KL'].append(float(model.kl_divergence(q_zs, p_zs).mean().data))
        else:
            raise NotImplemented

        print_['ELBO'].append(print_['recon loss'][-1] - print_['KL'][-1])
        # print_['LL'].append(float(log_likelihood(model, x).data))

    print({k: np.mean(v) for k, v in print_.items()})

Using model

In [59]:
# hidden dimension and dimension of latent space
H_DIM = 128
Z_DIM = 2
torch.autograd.set_detect_anomaly(True)
X = torch.tensor(adata.X)
X = X.float()
n_gene = X.numpy().shape[1]

# normal VAE
#modelN = ModelVAE(n_gene=n_gene, z_dim=Z_DIM, encoder_layer=None, decoder_layer=None, distribution='normal', x=X)
# print(modelN.parameters)
# optimizerN = optim.Adam(modelN.parameters(), lr=1e-3)

print('##### Normal VAE #####')

n_epochs = 20
# training for 1 epoch
# for i in range(0, n_epochs):
# train(modelN, optimizerN)

# test
# test(modelN, optimizerN)

print()

save_file_names = {"v1": [1, 1],
                   "v2": [3, 2],
                   "v3": [10, 10, 10, 10],
                   "v4": [20, 10, 6, 1],
                   "v5": [15, 10, 4, 3, 2, 1],
                   "v6": [20, 20]
                   }
# for z_dim in [2, 5, 10, 20]:
#     for Z_DIMS in [[z_dim, z_dim], [z_dim, z_dim, z_dim], [z_dim, z_dim, z_dim, z_dim, z_dim]]:
# for Z_DIMS in [[1, 1],
               # [3, 2],
               # [9, 9, 9, 9],
               # [10, 10, 10, 10],
               # [20, 10, 6, 1],
               # [15, 10, 4, 3, 2, 1],
               # [20, 20]]:
# save_file_names = {}
for version in save_file_names:
    Z_DIMS = save_file_names[version]
    # Z_DIMS = [2, 2]
    # hyper-spherical  VAE
    encoder_layer = None
    decoder_layer = None
    learning_rate = 1e-5
    use_l2_regularization = False
    modelS = ProductSpaceVAE(n_gene=n_gene, z_dims=[z + 1 for z in Z_DIMS], encoder_layer=encoder_layer, decoder_layer=decoder_layer,
                             distribution='vmf', device='cuda').to('cuda')
    # print(modelS.parameters)
    # optimizerS = optim.SGD(modelS.parameters(), lr=1e-5)
    optimizerS = optim.Adam(modelS.parameters(), lr=learning_rate)

    print('##### Product Space Hyper-spherical VAE #####')

    s_epochs = 10
    # training for 1 epoch
    print(Z_DIMS)
    for i in range(0, s_epochs):
        train(modelS, optimizerS)
    save_model(f"product_svae_{data}_{version}")
    # test
    test(modelS, optimizerS)

##### Normal VAE #####

##### Product Space Hyper-spherical VAE #####
[1, 1]




{'recon loss': -23.196541244279658, 'KL': 2.7919460088569235, 'ELBO': -25.988487253136583}
##### Product Space Hyper-spherical VAE #####
[3, 2]
{'recon loss': -12.681391839058168, 'KL': 3.6873143132206234, 'ELBO': -16.36870615227879}
##### Product Space Hyper-spherical VAE #####
[10, 10, 10, 10]
{'recon loss': -21.91185671046709, 'KL': 7.335589235686695, 'ELBO': -29.24744594615378}
##### Product Space Hyper-spherical VAE #####
[20, 10, 6, 1]
{'recon loss': -20.7835997067965, 'KL': 6.546653262891863, 'ELBO': -27.330252969688363}
##### Product Space Hyper-spherical VAE #####
[15, 10, 4, 3, 2, 1]
{'recon loss': -31.932040699658263, 'KL': 10.44656860674231, 'ELBO': -42.37860930640058}
##### Product Space Hyper-spherical VAE #####
[20, 20]
{'recon loss': -11.908658065511926, 'KL': 2.566489924073334, 'ELBO': -14.475147989585261}


In [60]:
def save_model(model_filename):
  torch.save({
      "Z_DIMS": Z_DIMS,
      "encoder_layer": encoder_layer,
      "decoder_layer": decoder_layer,
      "s_epochs": s_epochs,
      "learning_rate": learning_rate,
      "use_l2_regularization": False,
      "model_state_dict": modelS.state_dict(),
      "optimizer_state_dict": optimizerS.state_dict()
  },
  model_filename)

In [61]:
saved_model = "product_svae_b_cells_v1"
if saved_model:
  checkpoint = torch.load(saved_model)
  n_gene = 1196
  Z_DIMS = checkpoint["Z_DIMS"]
  encoder_layer = checkpoint["encoder_layer"]
  decoder_layer = checkpoint["decoder_layer"]
  learning_rate = checkpoint["learning_rate"]

  modelS = ProductSpaceVAE(n_gene=n_gene, z_dims=[z + 1 for z in Z_DIMS], encoder_layer=encoder_layer, decoder_layer=decoder_layer,
                          distribution='vmf', device='cuda').to('cuda')
  optimizerS = optim.Adam(modelS.parameters(), lr=learning_rate)

  modelS.load_state_dict(checkpoint['model_state_dict'])
  optimizerS.load_state_dict(checkpoint['optimizer_state_dict'])

  test(modelS, optimizerS)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x1816 and 1196x128)