In [15]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import NearestNeighbors

<IPython.core.display.Javascript object>

In [16]:
class CVAE(tf.keras.Model):
    def __init__(self, input_dim, latent_dim, hidden_dim):
        super(CVAE, self).__init__()
        
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        
        self.encoder = tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=(self.input_dim,)),
            tf.keras.layers.Dense(self.hidden_dim, activation='relu'),
            tf.keras.layers.Dense(self.hidden_dim, activation='relu'),
            tf.keras.layers.Dense(self.latent_dim * 2)
        ])
        
        self.decoder = tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=(self.latent_dim,)),
            tf.keras.layers.Dense(self.hidden_dim, activation='relu'),
            tf.keras.layers.Dense(self.hidden_dim, activation='relu'),
            tf.keras.layers.Dense(self.input_dim)
        ])
        
        self.coupled_encoder = tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=(self.input_dim + self.input_dim,)),
            tf.keras.layers.Dense(self.hidden_dim, activation='relu'),
            tf.keras.layers.Dense(self.hidden_dim, activation='relu'),
            tf.keras.layers.Dense(self.latent_dim * 2)
        ])
    
    def encode(self, x):
        mean, logvar = tf.split(self.encoder(x), num_or_size_splits=2, axis=1)
        return mean, logvar
    
    def reparameterize(self, mean, logvar):
        eps = tf.random.normal(shape=mean.shape)
        return eps * tf.exp(logvar * 0.5) + mean
    
    def decode(self, z):
        return self.decoder(z)
    
    def coupled_encode(self, x):
        return self.coupled_encoder(x)
    
    def call(self, x):
        x1, x2 = x
        
        # encode scRNA-seq data and sample from the distribution
        mean1, logvar1 = self.encode(x1)
        z1 = self.reparameterize(mean1, logvar1)
        
        # encode scATAC-seq data and sample from the distribution
        mean2, logvar2 = self.encode(x2)
        z2 = self.reparameterize(mean2, logvar2)
        
        # concatenate the two latent variables
        z = tf.concat([z1, z2], axis=1)
        
        # decode from the concatenated latent variable
        x_recon = self.decode(z)
        
        return x_recon, mean1, logvar1, mean2, logvar2


<IPython.core.display.Javascript object>

In [17]:
def compute_loss(model, x, k=5):
    x1, x2 = x
    x_recon, mean1, logvar1, mean2, logvar2 = model([x1, x2])
    
    # compute reconstruction loss
    recon_loss = tf.reduce_mean(tf.square(x_recon - x1))
    
    # compute KL divergence loss
    kl_loss1 = -0.5 * tf.reduce_mean(1 + logvar1 - tf.square(mean1) - tf.exp(logvar1))
    kl_loss2 = -0.5 * tf.reduce_mean(1 + logvar2 - tf.square(mean2) - tf.exp(logvar2))
    
    # compute matching loss
    z1 = model.encode(x1)[0]
    z2 = model.encode(x2)[0]
    
    # compute k-nearest neighbors for scRNA-seq data
    nbrs1 = NearestNeighbors(n_neighbors=k, algorithm='ball_tree').fit(z1)
    distances1, indices1 = nbrs1.kneighbors(z2)
    
    # compute k-nearest neighbors for scATAC-seq data
    nbrs2 = NearestNeighbors(n_neighbors=k, algorithm='ball_tree').fit(z2)
    distances2, indices2 = nbrs2.kneighbors(z1)
    
    # compute optimal matching using Hungarian algorithm
    cost_matrix = np.zeros((x1.shape[0], x2.shape[0]))
    for i, j in zip(indices1.flatten(), range(indices1.shape[0])):
        cost_matrix[i, j] = distances1.flatten()[j]
    row_ind, col_ind = linear_sum_assignment(cost_matrix)
    
    # compute matching loss
    match_loss = np.mean(cost_matrix[row_ind, col_ind])
    
    return recon_loss, kl_loss1, kl_loss2, match_loss


<IPython.core.display.Javascript object>

In [18]:
# define optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

# define early stopping
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)

# train the model
def train(model, x_train, x_val, epochs, batch_size):
    train_dataset = tf.data.Dataset.from_tensor_slices(x_train).shuffle(len(x_train)).batch(batch_size)
    val_dataset = tf.data.Dataset.from_tensor_slices(x_val).batch(batch_size)
    
    train_loss_results = []
    val_loss_results = []
    
    for epoch in range(epochs):
        train_loss_avg = tf.keras.metrics.Mean()
        val_loss_avg = tf.keras.metrics.Mean()
        
        # train the model
        for x in train_dataset:
            with tf.GradientTape() as tape:
                recon_loss, kl_loss1, kl_loss2, match_loss = compute_loss(model, x)
                loss = recon_loss + kl_loss1 + kl_loss2 + match_loss
                gradients = tape.gradient(loss, model.trainable_variables)
                optimizer.apply_gradients(zip(gradients, model.trainable_variables))
                train_loss_avg.update_state(loss)
        # compute validation loss
        for x in val_dataset:
            recon_loss, kl_loss1, kl_loss2, match_loss = compute_loss(model, x)
            loss = recon_loss + kl_loss1 + kl_loss2 + match_loss
            val_loss_avg.update_state(loss)

        train_loss_results.append(train_loss_avg.result())
        val_loss_results.append(val_loss_avg.result())

        # print progress
        if epoch % 10 == 0:
            print("Epoch {:03d}: Train Loss: {:.3f}, Val Loss: {:.3f}".format(epoch, train_loss_avg.result(), val_loss_avg.result()))

        # check for early stopping
        if len(val_loss_results) > early_stopping.patience:
            if val_loss_results[-1] >= val_loss_results[-1 * (early_stopping.patience + 1)]:
                print("Early stopping at epoch {:03d}".format(epoch))
                break

    return train_loss_results, val_loss_results


<IPython.core.display.Javascript object>

In [19]:
# encode new data into the latent space
x_rna_train = tf.convert_to_tensor(np.random.rand(500, 5000))
x_atac_train = tf.convert_to_tensor(np.random.rand(500, 5000))

x_rna_val = tf.convert_to_tensor(np.random.rand(100, 5000))
x_atac_val = tf.convert_to_tensor(np.random.rand(100, 5000))
# train the model
cvae = CVAE(input_dim=5000, latent_dim=64, hidden_dim=3)
train_loss_results, val_loss_results = train(
    cvae,
    [x_rna_train, x_atac_train],
    [x_rna_val, x_atac_val],
    epochs=100,
    batch_size=32,
)

# plot loss curves
plt.plot(train_loss_results, label="Train Loss")
plt.plot(val_loss_results, label="Val Loss")
plt.legend()
plt.show()

ValueError: Exception encountered when calling layer "cvae" "                 f"(type CVAE).

Input 0 of layer "sequential_1" is incompatible with the layer: expected shape=(None, 64), found shape=(500, 128)

Call arguments received by layer "cvae" "                 f"(type CVAE):
  • x=['tf.Tensor(shape=(500, 5000), dtype=float32)', 'tf.Tensor(shape=(500, 5000), dtype=float32)']

<IPython.core.display.Javascript object>

In [6]:
def encode(model, data):
    """
    Encode data into the latent space using the trained model.
    Args:
        model (CoupledVAE): Trained coupled variational autoencoder model.
        data (numpy.ndarray): Input data to be encoded.
    Returns:
        numpy.ndarray: Encoded latent space representation of the input data.
    """
    # split the input data into RNA and ATAC data
    x_rna, x_atac = data
    # encode the RNA and ATAC data into the latent space
    z_rna = model.encode(x_rna, modality='rna')
    z_atac = model.encode(x_atac, modality='atac')
    # concatenate the two latent space representations
    z = tf.concat([z_rna, z_atac], axis=0)
    return z.numpy()


In [None]:
# encode new data into the latent space
x_rna_new = np.random.rand(100, 5000)
x_atac_new = np.random.rand(80, 5000)
z_new = encode(cvae, [x_rna_new, x_atac_new])

# visualize the latent space
colors = np.concatenate([np.ones((100,1)), np.zeros((80,1))], axis=0) # assign different colors to RNA and ATAC data
plt.scatter(z_new[:,0], z_new[:,1], c=colors)
plt.show()


In [7]:
def match(model, data_atac):
    """
    Match scRNA-seq cells to scATAC-seq cells using the trained model.
    Args:
        model (CoupledVAE): Trained coupled variational autoencoder model.
        data_atac (numpy.ndarray): scATAC-seq data to match scRNA-seq cells to.
    Returns:
        numpy.ndarray: scRNA-seq data matched to the scATAC-seq cells.
    """
    # encode the scATAC-seq data into the latent space
    z_atac = model.encode(data_atac, modality='atac')
    # initialize an empty list to store the matched scRNA-seq cells
    data_rna_matched = []
    # loop over each scATAC-seq cell
    for i in range(z_atac.shape[0]):
        # compute the Euclidean distance between the scATAC-seq cell and all scRNA-seq cells in the latent space
        dist = tf.reduce_sum(tf.square(z_atac[i,:] - model.z_rna), axis=1)
        # find the index of the closest scRNA-seq cell
        index = tf.argmin(dist)
        # append the matched scRNA-seq cell to the list
        data_rna_matched.append(model.x_rna[index,:])
    # convert the list to a numpy array
    data_rna_matched = np.array(data_rna_matched)
    return data_rna_matched


In [None]:
# generate some new scATAC-seq data
x_atac_new = np.random.rand(50, 5000)

# match the new scATAC-seq data to scRNA-seq data
x_rna_matched = match(cvae, x_atac_new)

# compare the original scATAC-seq data to the matched scRNA-seq data
for i in range(x_atac_new.shape[0]):
    print("Original ATAC Cell {}: {}".format(i, x_atac_new[i,:10]))
    print("Matched RNA Cell {}: {}".format(i, x_rna_matched[i,:10]))
