In [1]:
import os
import sys
sys.path.append("/home/roh3635/hyperbolic-cancer/mvae/mt/mvae")

import torch
import mvae.mt.mvae.utils as utils
import mvae.mt.mvae.models as models

In [2]:
# Pick the compute device name: "cuda" when PyTorch reports a usable CUDA
# runtime, otherwise "cpu".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as dist
import math

class MixedCurvatureVAE(nn.Module):
    def __init__(
        self, 
        n_gene, 
        n_batch=None, 
        z_dim=2,
        encoder_layer=None, 
        decoder_layer=None, 
        latent_dist="vmf", 
        observation_dist="nb",
        batch_invariant=False
    ):
        super(MixedCurvatureVAE, self).__init__()
        
        if encoder_layer is None:
            encoder_layer = [128, 64, 32]
        if decoder_layer is None:
            decoder_layer = [32, 128]
        
        self.batch_invariant = batch_invariant
        self.n_input_feature = n_gene
        self.z_dim = z_dim
        self.encoder_layer = encoder_layer
        self.decoder_layer = decoder_layer
        self.latent_dist = latent_dist
        self.observation_dist = observation_dist
        
        if self.latent_dist == "vmf":
            self.z_dim += 1
        
        if not isinstance(n_batch, list):
            n_batch = [n_batch]
        
        self.n_batch = n_batch
        
        self._build_encoder()
        self._build_decoder()
    
    def _build_encoder(self):
        layers = []
        
        input_size = self.n_input_feature
        if not self.batch_invariant:
            input_size += sum(self.n_batch)
        
        in_features = input_size
        for units in self.encoder_layer:
            layers.append(nn.Linear(in_features, units))
            layers.append(nn.BatchNorm1d(units))
            layers.append(nn.ELU())
            in_features = units
        
        self.encoder_hidden = nn.Sequential(*layers)
        
        if self.latent_dist == "normal":
            self.z_mu_layer = nn.Linear(self.encoder_layer[-1], self.z_dim)
            self.z_sigma_square_layer = nn.Linear(self.encoder_layer[-1], self.z_dim)
        elif self.latent_dist == "vmf":
            self.z_mu_layer = nn.Linear(self.encoder_layer[-1], self.z_dim)
            self.z_sigma_square_layer = nn.Linear(self.encoder_layer[-1], 1)
        elif self.latent_dist == "wn":
            self.z_mu_layer = nn.Linear(self.encoder_layer[-1], self.z_dim)
            self.z_sigma_square_layer = nn.Linear(self.encoder_layer[-1], self.z_dim)
        else:
            raise NotImplementedError()
    
    def _build_decoder(self):
        layers = []
        
        input_size = self.z_dim + sum(self.n_batch)
        
        in_features = input_size
        for units in self.decoder_layer:
            layers.append(nn.Linear(in_features, units))
            layers.append(nn.BatchNorm1d(units))
            layers.append(nn.ELU())
            in_features = units
        
        self.decoder_hidden = nn.Sequential(*layers)
        
        self.mu_layer = nn.Linear(self.decoder_layer[-1], self.n_input_feature)
        self.sigma_square_layer = nn.Linear(self.decoder_layer[-1], self.n_input_feature)
    
    def multi_one_hot(self, index_tensor, depth_list):
        """Concatenate one-hot encodings of several categorical columns.

        Args:
            index_tensor: 2-D tensor of category indices; column ``c`` is
                encoded with ``depth_list[c]`` classes.
            depth_list: Number of classes for each encoded column.

        Returns:
            A tensor of zeros and ones with shape
            ``(index_tensor.size(0), sum(depth_list))`` on the same device as
            ``index_tensor``, in the default floating dtype. An index outside
            ``[0, depth)`` -- including a negative one -- leaves that
            column's slice all zero for the row.
        """
        n_rows = index_tensor.size(0)
        device = index_tensor.device
        one_hot_tensor = torch.zeros(n_rows, sum(depth_list), device=device)
        rows = torch.arange(n_rows, device=device)

        offset = 0
        for col, depth in enumerate(depth_list):
            idx = index_tensor[:, col].long()
            # The lower bound matters: a negative index would otherwise wrap
            # around and switch on a column belonging to another slice.
            valid = (idx >= 0) & (idx < depth)
            one_hot_tensor[rows[valid], offset + idx[valid]] = 1
            offset += depth

        return one_hot_tensor
    
    def _encoder(self, x, batch):
        if self.observation_dist == "nb":
            x = torch.log1p(x)
            
            if self.latent_dist == "vmf":
                x = F.normalize(x, p=2, dim=-1)
        
        if not self.batch_invariant:
            x = torch.cat([x, batch], dim=1)
        
        h = self.encoder_hidden(x)
        
        if self.latent_dist == "normal":
            z_mu = self.z_mu_layer(h)
            z_sigma_square = F.softplus(self.z_sigma_square_layer(h))
        elif self.latent_dist == "vmf":
            z_mu = self.z_mu_layer(h)
            z_mu = F.normalize(z_mu, p=2, dim=-1)
            z_sigma_square = F.softplus(self.z_sigma_square_layer(h)) + 1
            z_sigma_square = torch.clamp(z_sigma_square, 1, 10000)
        elif self.latent_dist == "wn":
            z_mu = self.z_mu_layer(h)
            z_mu = self._polar_project(z_mu)
            z_sigma_square = F.softplus(self.z_sigma_square_layer(h))
        else:
            raise NotImplementedError
        
        return z_mu, z_sigma_square
    
    def _decoder(self, z, batch):
        z = torch.cat([z, batch], dim=1)
        
        h = self.decoder_hidden(z)
        
        if self.observation_dist == "nb":
            mu = F.softmax(self.mu_layer(h), dim=1)
            mu = mu * self.library_size
            
            sigma_square = F.softplus(self.sigma_square_layer(h))
            sigma_square = torch.mean(sigma_square, dim=0)
        else:
            mu = self.mu_layer(h)
            sigma_square = F.softplus(self.sigma_square_layer(h))
        
        sigma_square = torch.clamp(sigma_square, self.EPS, self.MAX_SIGMA_SQUARE)
        
        return mu, sigma_square
    
    def _clip_min_value(self, x, eps=1e-6):
        return F.relu(x - eps) + eps
    
    def _polar_project(self, x):
        x_norm = torch.sum(torch.square(x), dim=1, keepdim=True)
        x_norm = torch.sqrt(self._clip_min_value(x_norm))
        
        x_unit = x / x_norm
        x_norm = torch.clamp(x_norm, 0, 32)
        
        z = torch.cat([
            torch.cosh(x_norm), 
            torch.sinh(x_norm) * x_unit
        ], dim=1)
        
        return z
    
    def _depth_regularizer(self, x, batch):
        with torch.no_grad():
            rate = x * 0.2
            samples = torch.poisson(rate)
        
        x_perturbed = F.relu(x - samples)
        z_mu_hat, _ = self._encoder(x_perturbed, batch)
        
        mean_diff = torch.sum(torch.pow(self.z_mu - z_mu_hat, 2), dim=1)
        loss = torch.mean(mean_diff)
        
        return loss
    
    def log_likelihood_nb(self, x, mu, sigma_square, eps=1e-10):
        """Element-wise negative binomial log-likelihood.

        Args:
            x: Observed counts.
            mu: Mean of the distribution, broadcastable against ``x``.
            sigma_square: Inverse-dispersion parameter of this
                parameterisation (it enters as ``lgamma(x + sigma_square)``),
                broadcastable against ``x``.
            eps: Small constant added inside every logarithm for stability.

        Returns:
            Tensor of log-likelihood values with the broadcast shape of the
            inputs.
        """
        log_theta_mu_eps = torch.log(sigma_square + mu + eps)
        # eps is applied inside all three logs. Without it, mu == 0 yields
        # log(0) = -inf and the x == 0 term becomes 0 * -inf = NaN.
        res = (
            torch.lgamma(x + sigma_square)
            - torch.lgamma(x + 1)
            - torch.lgamma(sigma_square)
            + sigma_square * (torch.log(sigma_square + eps) - log_theta_mu_eps)
            + x * (torch.log(mu + eps) - log_theta_mu_eps)
        )
        return res
    
    def log_likelihood_student(self, x, mu, sigma_square, df=5.0):
        """Element-wise Student-t log-density with squared scale ``sigma_square``.

        Args:
            x: Observed values (a tensor).
            mu: Location, broadcastable against ``x``.
            sigma_square: Squared scale, broadcastable against ``x``.
            df: Degrees of freedom; a Python number or a tensor.

        Returns:
            Tensor of log-density values with the broadcast shape of the
            inputs.
        """
        # torch.lgamma only accepts tensors, so a plain float df (the default)
        # used to raise TypeError on every call. Lift it to a tensor that
        # lives next to x; keep a floating dtype even if x holds integers.
        dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
        df = torch.as_tensor(df, dtype=dtype, device=x.device)
        df_halved = df / 2
        return (
            torch.lgamma(df_halved + 0.5) - torch.lgamma(df_halved) -
            0.5 * torch.log(math.pi * df * sigma_square) -
            (df_halved + 0.5) * torch.log(1 + (x - mu)**2 / (df * sigma_square))
        )
    
    def forward(self, x, batch_id):
        if len(self.n_batch) > 1:
            batch = self.multi_one_hot(batch_id, self.n_batch)
        else:
            batch = F.one_hot(batch_id, self.n_batch[0]).float()
        
        self.library_size = torch.sum(x, dim=1, keepdim=True)
        
        self.z_mu, self.z_sigma_square = self._encoder(x, batch)
        
        # Sample latent variable
        if self.latent_dist == "normal":
            self.q_z = dist.Normal(self.z_mu, torch.sqrt(self.z_sigma_square))
            self.z = self.q_z.rsample()
            self.p_z = dist.Normal(torch.zeros_like(self.z), torch.ones_like(self.z))
            kl = dist.kl_divergence(self.q_z, self.p_z).sum(dim=1)
            self.kl = torch.mean(kl)
        elif self.latent_dist == 'vmf':
            self.q_z = VonMisesFisher(self.z_mu, self.z_sigma_square)
            self.z = self.q_z.sample()
            self.p_z = HypersphericalUniform(self.z_dim - 1, dtype=x.dtype)
            kl = self.q_z.kl_divergence(self.p_z)
            self.kl = torch.mean(kl)
        elif self.latent_dist == 'wn':
            self.q_z = HyperbolicWrappedNorm(self.z_mu, self.z_sigma_square)
            self.z = self.q_z.sample()
            tmp = self._polar_project(torch.zeros_like(self.z_sigma_square))
            self.p_z = HyperbolicWrappedNorm(tmp, torch.ones_like(self.z_sigma_square))
            kl = self.q_z.log_prob(self.z) - self.p_z.log_prob(self.z)
            self.kl = torch.mean(kl)
        else:
            raise NotImplementedError
        
        # Decoder
        self.mu, self.sigma_square = self._decoder(self.z, batch)
        
        # Depth regularization
        self.depth_loss = self._depth_regularizer(x, batch)
        
        # ELBO calculation
        if