In [57]:
import os
import torch
import gc
torch.cuda.empty_cache()
gc.collect()

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  # Arrange GPU devices starting from 0
os.environ["CUDA_VISIBLE_DEVICES"]= ""  # Set the GPU

# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

In [58]:
import numpy as np

# Generate toy RNA-seq data
num_genes = 10
num_samples = 5
gene_means = np.random.normal(0, 1, size=num_genes)
gene_vars = np.random.gamma(1, 1, size=num_genes)
data = np.random.normal(gene_means, np.sqrt(gene_vars), size=(num_samples, num_genes))
data

array([[ 1.24571316,  1.20752331, -1.34911164, -1.08253465,  1.49166842,
         0.20202308, -0.65336259, -0.13298715,  0.96503092, -1.55293706],
       [ 1.0871331 ,  0.44521698, -0.98670451, -2.23006552, -0.41548319,
         0.25337438,  0.44956698, -0.86520142, -0.0547919 ,  0.79194314],
       [ 1.97585154,  1.08755116, -0.87196726, -0.53131319,  0.75324479,
         0.32084094, -2.35121855, -0.74390211,  0.64182779, -1.85149413],
       [ 2.01916171,  0.65104069, -0.71712066, -0.95691518,  1.98095011,
         0.43648931, -0.7118163 , -0.60416282, -1.14906209, -0.23826607],
       [ 1.35141038,  1.3247779 , -0.22267992, -0.79041832,  2.69019775,
         0.49025555,  0.71649408, -0.42653473, -1.27791003, -0.08681777]])

In [59]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.optim import Adam
from torch.nn import functional as F
from typing import Tuple
from sklearn.preprocessing import StandardScaler
import scglue.models as GLUE
import pandas as pd

# Define the RNAseq data
n_samples = 100
n_genes = 200
rna_data = torch.randn(n_genes, n_samples)

# Define the data loader
data = TensorDataset(rna_data)
data_loader = DataLoader(data, batch_size=10, shuffle=True)


In [60]:
# # Define the encoder and decoder
# latent_dim = 2
# encoder = GLUE.glue.DataEncoder(n_genes)
# decoder = GLUE.glue.DataDecoder(n_genes)

# # Define the prior
# prior = GLUE.glue.Prior()

# # Define the optimizer
# lr = 1e-3
# optimizer = Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=lr)

# # Define the training loop
# n_epochs = 100
# for epoch in range(n_epochs):
#     epoch_loss = 0.0
#     for batch_idx, batch in enumerate(data_loader):
#         # Zero the gradients
#         optimizer.zero_grad()

#         # Get the data
#         x = batch[0]

#         # Forward pass
#         q_z_x, z_l = encoder(x, torch.empty(0), lazy_normalizer=False)
#         z = q_z_x.rsample()
#         p_x_z = decoder(*z.shape, z, torch.zeros_like(z), torch.empty(0))

#         # Compute the loss
#         kl_divergence = torch.distributions.kl.kl_divergence(q_z_x, prior()).mean()
#         reconstruction_loss = -p_x_z.log_prob(x).mean()
#         loss = kl_divergence + reconstruction_loss

#         # Backward pass and optimization step
#         loss.backward()
#         optimizer.step()

#         # Update the epoch loss
#         epoch_loss += loss.item()

#     # Print the epoch loss
#     print(f"Epoch {epoch}: Loss = {epoch_loss / (batch_idx+1)}")


In [65]:
import torch
import torch.nn.functional as F
import torch.distributions as D
from abc import abstractmethod
from typing import Optional, Tuple

EPS = 1e-8


class DataEncoder(torch.nn.Module):
    def __init__(
            self, in_features: int, out_features: int,
            h_depth: int = 2, h_dim: int = 256,
            dropout: float = 0.2
    ) -> None:
        super().__init__()
        self.h_depth = h_depth
        ptr_dim = in_features
        for layer in range(self.h_depth):
            setattr(self, f"linear_{layer}", torch.nn.Linear(ptr_dim, h_dim))
            setattr(self, f"act_{layer}", torch.nn.LeakyReLU(negative_slope=0.2))
            setattr(self, f"bn_{layer}", torch.nn.BatchNorm1d(h_dim))
            setattr(self, f"dropout_{layer}", torch.nn.Dropout(p=dropout))
            ptr_dim = h_dim
        self.loc = torch.nn.Linear(ptr_dim, out_features)
        self.std_lin = torch.nn.Linear(ptr_dim, out_features)

    def compute_l(self, x: torch.Tensor) -> Optional[torch.Tensor]:
        return None

    def normalize(
            self, x: torch.Tensor, l: Optional[torch.Tensor]
    ) -> torch.Tensor:
        return x
    
    def forward(
            self, x: torch.Tensor, xrep: torch.Tensor,
            lazy_normalizer: bool = True
    ) -> Tuple[D.Normal, Optional[torch.Tensor]]:
        if xrep.numel():
            l = None if lazy_normalizer else self.compute_l(x)
            ptr = xrep
        else:
            l = self.compute_l(x)
            ptr = self.normalize(x, l)
        for layer in range(self.h_depth):
            ptr = getattr(self, f"linear_{layer}")(ptr)
            ptr = getattr(self, f"act_{layer}")(ptr)
            ptr = getattr(self, f"bn_{layer}")(ptr)
            ptr = getattr(self, f"dropout_{layer}")(ptr)
        loc = self.loc(ptr)
        std = F.softplus(self.std_lin(ptr)) + EPS
        return D.Normal(loc, std), l


class DataDecoder(torch.nn.Module):
    def __init__(
            self, out_features: int, n_batches: int
    ) -> None:
        super().__init__()
        self.scale_lin = torch.nn.Parameter(torch.zeros(1, n_batches, out_features))
        self.bias = torch.nn.Parameter(torch.zeros(n_batches, out_features))
        self.log_theta = torch.nn.Parameter(torch.zeros(n_batches, out_features))
        self.lin_dec = torch.nn.Linear(out_features, out_features)

    def forward(
            self, u: torch.Tensor, 
            # v: torch.Tensor,
            b: torch.Tensor, 
            l: torch.Tensor
    ) -> D.NegativeBinomial:
        # scale = F.softplus(self.scale_lin[b])
        logit_mu = self.lin_dec(u)
        print(logit_mu.size())
        print(u.size())
        # logit_mu = scale * (u @ v.t()) + self.bias[b]
        mu = F.softmax(logit_mu, dim=1) * l
        log_theta = self.log_theta[b]
        return D.NegativeBinomial(
            log_theta.exp(),
            logits=(mu + EPS).log() - log_theta
        )



class Prior(torch.nn.Module):
    def __init__(
            self, loc: float = 0.0, std: float = 1.0
    ) -> None:
        super().__init__()
        loc = torch.as_tensor(loc, dtype=torch.get_default_dtype())
        std = torch.as_tensor(std, dtype=torch.get_default_dtype())
        self.register_buffer("loc", loc)
        self.register_buffer("std", std)

    def forward(self) -> D.Normal:
        return D.Normal(self.loc, self.std)


In [62]:
# from sc import *
num_batch = 5
latent_feature = 10
num_epochs = 100
encoder = DataEncoder(in_features=n_samples, out_features=latent_feature)
decoder = DataDecoder(out_features=latent_feature, n_batches=num_batch)
# # Define the optimizer
lr = 0.1
optimizer = Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=lr)
prior = Prior()


In [63]:
def train_autoencoder(
    x: torch.Tensor,
    encoder: torch.nn.Module,
    decoder: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    batch_size: int,
    prior: Optional[torch.nn.Module] = None,
    beta: float = 1.0,
    num_epochs: int = 1000
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Set the model to training mode
    encoder.train()
    decoder.train()

    # Create a data loader
    dataset = x
    loader = data_loader

    # Train the model
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for batch in loader:
            optimizer.zero_grad()

            # Forward pass through the encoder and compute the latent distribution
            x_batch = batch[0]
            q_z_x, _ = encoder(x_batch, torch.tensor([]), lazy_normalizer=True)

            # Sample a latent variable z from the distribution
            z = q_z_x.rsample()

            # If a prior is given, compute the KL divergence between q(z|x) and p(z)
            if prior is not None:
                p_z = prior()
                kl_div = D.kl.kl_divergence(q_z_x, p_z).sum(dim=1)
                kl_loss = beta * kl_div.mean()
            else:
                kl_loss = 0.0

            # Decode the latent variable to get the reconstructed data distribution
            b = torch.zeros(batch_size, dtype=torch.long)

            p_x_z = decoder(z, b, l=torch.ones_like(x_batch))

            # Compute the negative log likelihood loss between the original data and the reconstructed data
            x_log_probs = p_x_z.log_prob(x_batch).sum(dim=1)
            nll_loss = -x_log_probs.mean()

            # Compute the total loss and backpropagate the gradients
            loss = nll_loss + kl_loss
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item() * x_batch.size(0)

        print(f"Epoch {epoch+1}/{num_epochs}: Loss = {epoch_loss/len(dataset):.6f}")

    # Set the model to evaluation mode
    encoder.eval()
    decoder.eval()

    # Encode the data to get the latent variables
    with torch.no_grad():
        q_z_x, _ = encoder(x, torch.tensor([]), lazy_normalizer=True)
        z = q_z_x.rsample()

        # Decode the latent variables to get the reconstructed data distribution
        b = torch.zeros(x.size(0), dtype=torch.long)
        p_x_z = decoder(z, b, l=torch.ones_like(x))

        # Compute the negative log likelihood loss between the original data and the reconstructed data
        x_log_probs = p_x_z.log_prob(x).sum(dim=1)
        nll_loss = -x_log_probs.mean()

    return z, p_x_z.mean, nll_loss


In [64]:
test_vae = train_autoencoder(data, encoder, decoder, optimizer, batch_size=num_batch, prior=prior, num_epochs=100)

torch.Size([10, 10])
torch.Size([10, 10])


RuntimeError: The size of tensor a (10) must match the size of tensor b (100) at non-singleton dimension 1

In [None]:
def train_autoencoder(
        data: torch.Tensor, model: torch.nn.Module,
        epochs: int = 100, batch_size: int = 32,
        learning_rate: float = 1e-3
) -> Tuple[torch.nn.Module, torch.Tensor]:
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    n_batches = (data.size(0) + batch_size - 1) // batch_size
    for epoch in range(1, epochs+1):
        epoch_loss = 0.0
        for batch in range(n_batches):
            start = batch * batch_size
            end = min((batch+1) * batch_size, data.size(0))
            x = data[start:end]
            optimizer.zero_grad()
            dist, _ = model(x, x.new_empty(0))
            z = dist.rsample()
            recon_dist = model.decoder(z, z.new_empty(0), torch.zeros_like(x), torch.ones_like(x))
            loss = -recon_dist.log_prob(x).mean()
            kl_loss = D.kl_divergence(dist, model.prior()).mean()
            loss += kl_loss
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item() * (end - start)
        epoch_loss /= data.size(0)
        print(f"Epoch {epoch}/{epochs}: Loss={epoch_loss:.4f}")
    return model, z.detach()
