In [1]:
# import os
# import sys
# sys.path.append("/home/roh3635/hyperbolic-cancer/mvae/mt/mvae")

import math
import numpy as np
import scipy.special
import torch
import torch.distributions as distributions
import torch.nn.functional as F

In [2]:
# Natural-log constants shared by the distribution helpers in this cell.
# math.log returns a plain Python float, so the constants are floats, not tensors.
ln_2: float = math.log(2)
ln_pi: float = math.log(math.pi)
ln_2pi: float = ln_2 + ln_pi


class EuclideanNormal(distributions.Normal):
    """Normal distribution whose ``log_prob`` is reduced over the last axis."""

    def log_prob(self, value):
        """Return the element-wise Normal log-density summed over ``dim=-1``."""
        per_coordinate = super().log_prob(value)
        return per_coordinate.sum(dim=-1)


class EuclideanUniform(distributions.Uniform):
    """Uniform distribution whose ``log_prob`` is reduced over the last axis."""

    def log_prob(self, value):
        """Return the element-wise Uniform log-density summed over ``dim=-1``."""
        per_coordinate = super().log_prob(value)
        return per_coordinate.sum(dim=-1)
    

class IveFunction(torch.autograd.Function):
    """Autograd bridge to SciPy's exponentially scaled modified Bessel functions.

    The forward pass evaluates the function with SciPy on a float64 CPU copy of
    ``z``; the backward pass supplies the derivative with respect to ``z``.
    No gradient is produced for the order ``v``.
    """

    @staticmethod
    def forward(ctx, v, z):
        """Evaluate the scaled Bessel function of order ``v`` at tensor ``z``.

        The result is returned with the dtype and device of ``z``.
        """
        ctx.save_for_backward(z)
        ctx.v = v
        z_np = z.double().detach().cpu().numpy()
        # Orders 0 and 1 have dedicated SciPy routines; everything else goes
        # through the generic ``ive``.
        if np.isclose(v, 0):
            kernel, args = scipy.special.i0e, (z_np,)
        elif np.isclose(v, 1):
            kernel, args = scipy.special.i1e, (z_np,)
        else:
            kernel, args = scipy.special.ive, (v, z_np)
        values = kernel(*args, dtype=z_np.dtype)
        return torch.tensor(values, dtype=z.dtype, device=z.device)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(None, dL/dz)``; the order ``v`` receives no gradient."""
        z = ctx.saved_tensors[-1]
        order = ctx.v
        # d/dz ive(v, z) = ive(v - 1, z) - ive(v, z) * (v + z) / z
        d_dz = ive(order - 1, z) - ive(order, z) * (order + z) / z
        return None, grad_output * d_dz


def ive(v, z):
    """Differentiable (w.r.t. ``z``) wrapper: ``IveFunction.apply(v, z)``."""
    return IveFunction.apply(v, z)


class HypersphericalUniform(distributions.Distribution):
    """Uniform distribution on the unit sphere S^dim embedded in R^(dim + 1).

    Args:
        dim: intrinsic dimension of the sphere (samples have ``dim + 1``
            coordinates).
        validate_args: forwarded to ``torch.distributions.Distribution``.
            Defaults to ``None`` so the class can be built from ``dim`` alone.
        device: device on which samples and log-probabilities are created.
        dtype: optional dtype for samples and log-probabilities; ``None`` keeps
            PyTorch's default dtype.
    """
    arg_constraints = {}
    support = distributions.constraints.real
    _mean_carrier_measure = 0

    def __init__(self, dim, validate_args=None, device="cuda", dtype=None):
        # NOTE(review): ``dim`` is passed as the *batch* shape here, as in the
        # original code; confirm whether an event shape of (dim + 1,) was meant.
        super().__init__(torch.Size([dim]), validate_args=validate_args)
        self.dim = dim
        self.device = device
        self.dtype = dtype
        self.normal = EuclideanNormal(0, 1)

    def rsample(self, sample_shape=torch.Size()):
        """Draw points by normalising standard-normal vectors.

        Returns a tensor of shape ``sample_shape + (1, dim + 1)`` whose last
        axis has unit Euclidean norm.
        """
        sample_shape = torch.Size(sample_shape)
        output = self.normal.sample(sample_shape + torch.Size([1, self.dim + 1])).to(self.device)
        if self.dtype is not None:
            output = output.to(self.dtype)
        return F.normalize(output, dim=-1)

    def entropy(self):
        """Entropy of the uniform distribution: the log surface area of S^dim."""
        return self.__log_surface_area()

    def log_prob(self, x):
        """Constant log-density ``-log(area)`` with shape ``x.shape[:-1]``."""
        return -torch.ones(x.shape[:-1], device=self.device, dtype=self.dtype) * self.__log_surface_area()

    def __log_surface_area(self):
        """Return log(2 * pi^((dim+1)/2) / Gamma((dim+1)/2)) as a 0-dim tensor."""
        half = (self.dim + 1.) / 2.
        # Evaluated in Python floats (double precision) before becoming a tensor.
        value = math.log(2.0) + half * math.log(math.pi) - math.lgamma(half)
        return torch.tensor(value, dtype=self.dtype)


# class VonMisesFisher(distributions.Distribution):
#     arg_constraints = {
#         "loc": torch.distributions.constraints.real,
#         "scale": torch.distributions.constraints.positive
#     }
#     support = torch.distributions.constraints.real
#     _mean_carrier_measure = 0

#     @property
#     def mean(self):
#         return self.loc * (ive(self.p / 2, self.scale) / ive(self.p / 2 - 1, self.scale))

#     @property
#     def stddev(self):
#         return self.scale
    
#     def __init__(self, loc, scale, validate_args=None, device="cuda"):
#         self.dtype = loc.dtype
#         self.loc = loc
#         assert loc.norm(p=2, dim=-1).allclose(torch.ones(loc.shape[:-1], device=loc.device))
#         self.scale = scale
#         assert (scale > 0).all()
#         self.device = loc.device
#         self.p = loc.shape[-1]

#         self.uniform = EuclideanUniform(0, 1)
#         self.hyperspherical_uniform_v = HypersphericalUniform(self.p - 2, device=self.device)

#         # Pre-compute Householder transformation
#         e1 = torch.tensor([1.] + [0.] * (loc.shape[-1] - 1), requires_grad=False, device=self.device)
#         self.u = F.normalize(e1 - self.loc)

#         super().__init__(self.loc.size(), validate_args=validate_args)

In [2]:
# The von-Mises-Fisher distribution in PyTorch

import math
import torch
import torch.distributions as distributions

from torch.distributions import constraints
from torch.distributions.distribution import Distribution


class VonMisesFisher(Distribution):
    """von Mises-Fisher distribution on the unit sphere S^(m-1) in R^m.

    The density is ``p(x) = exp(scale * <loc, x>) / Z(scale)``.

    Args:
        loc: unit-norm mean direction, shape ``batch_shape + (m,)``.
        scale: positive concentration. Either ``batch_shape`` or, keepdim
            style, ``loc.shape[:-1] + (1,)`` (the shape a one-unit linear head
            produces). A trailing singleton axis is dropped when ``scale`` has
            the same rank as ``loc``.
        validate_args: when truthy, ``loc`` is checked for unit norm and
            ``scale`` for positivity here, and ``log_prob`` checks that its
            input is unit norm. Also forwarded to ``Distribution``.

    Notes:
        ``mean`` and ``mode`` are methods (not properties), as in the original
        implementation.
    """
    arg_constraints = {
        "loc": constraints.real_vector,
        "scale": constraints.positive
    }
    support = constraints.real_vector
    has_rsample = True

    def __init__(self, loc, scale, validate_args=None):
        self.dtype = loc.dtype
        self.device = loc.device

        if validate_args:
            loc_norm = torch.norm(loc, dim=-1)
            if not torch.allclose(loc_norm, torch.ones_like(loc_norm), atol=1e-5):
                raise ValueError("`loc` must be normalized to unit length")
            if not torch.all(scale > 0):
                raise ValueError("`scale` must be positive")

        self.loc = loc
        self.scale = scale
        # Distribution has no public ``validate_args`` attribute, so keep the
        # caller's explicit choice for the unit-norm check in log_prob.
        self._validate_unit_norm = bool(validate_args)

        self.__m = loc.shape[-1]
        if self.__m < 2:
            raise ValueError("`loc` must have at least 2 components")
        self.__e1 = torch.zeros(self.__m, device=self.device, dtype=self.dtype)  # Reference vector
        self.__e1[0] = 1.0

        # Accept keepdim-style concentration: (..., 1) alongside loc (..., m).
        if scale.dim() == loc.dim() and scale.shape[-1] == 1:
            kappa = scale.squeeze(-1)
        else:
            kappa = scale
        batch_shape = torch.broadcast_shapes(loc.shape[:-1], kappa.shape)
        event_shape = torch.Size([self.__m])
        # Concentration broadcast to exactly ``batch_shape``; used internally.
        self._kappa = kappa.expand(batch_shape)

        super(VonMisesFisher, self).__init__(batch_shape, event_shape, validate_args)

    def rsample(self, sample_shape=torch.Size()):
        """Sample points of shape ``sample_shape + batch_shape + (m,)``.

        A scalar ``w`` (the component along the mean direction) is drawn, then
        combined with a uniform direction in the orthogonal complement and
        reflected so that the first axis maps onto ``loc``.

        For ``m == 3`` ``w`` is an inverse-CDF transform of uniform noise. For
        other ``m`` a rejection sampler is used.
        NOTE(review): the accept/reject step is not differentiated, so for
        ``m != 3`` gradients w.r.t. ``scale`` only flow through the proposal
        transform -- confirm this is acceptable for training.
        """
        shape = torch.Size(sample_shape) + self.batch_shape

        if self.__m == 3:
            w = self._sample_w3(shape)
        else:
            w = self._sample_w_rej(shape)

        # Keep sqrt(1 - w^2) away from 0 so its gradient stays finite.
        w = torch.clamp(w, -1 + 1e-6, 1 - 1e-6)

        v_shape = shape + torch.Size([self.__m - 1])
        v = torch.randn(v_shape, device=self.device, dtype=self.dtype)
        v = torch.nn.functional.normalize(v, dim=-1)

        # w already carries a trailing axis of size 1, so it broadcasts
        # against v directly.
        tangent_norm = torch.sqrt(1.0 + w) * torch.sqrt(1.0 - w)
        x = torch.cat([w, tangent_norm * v], dim=-1)

        return self._householder_rotation(x)

    def _sample_w3(self, shape):
        """Inverse-CDF sample of ``w`` for m == 3; shape ``shape + (1,)``.

        ``w = 1 + log(u + (1 - u) * exp(-2 * kappa)) / kappa`` evaluated as a
        log-sum-exp for numerical stability.
        """
        w_shape = shape + torch.Size([1])
        kappa = self._kappa.unsqueeze(-1)
        u = torch.rand(w_shape, device=self.device, dtype=self.dtype)
        # torch.rand can return exactly 0, which would give log(0).
        u = torch.clamp(u, min=torch.finfo(self.dtype).tiny)

        w = 1 + torch.logaddexp(torch.log(u), torch.log1p(-u) - 2 * kappa) / kappa
        return w

    def _sample_w_rej(self, shape):
        """Rejection-sample ``w`` for general m; shape ``shape + (1,)``.

        Proposals are a Beta((m-1)/2, (m-1)/2) draw pushed through a rational
        transform; a proposal is kept when
        ``kappa * w + (m-1) * log(1 - x*w) - c >= log(u)``.
        At most 100 rounds are run; entries never accepted keep the value 0.
        """
        w_shape = shape + torch.Size([1])
        dim = float(self.__m - 1)
        kappa = self._kappa.unsqueeze(-1)

        # Loop-invariant envelope parameters.
        b = dim / (2.0 * kappa + torch.sqrt(4.0 * kappa ** 2 + dim ** 2))
        x = (1.0 - b) / (1.0 + b)
        c = kappa * x + dim * torch.log1p(-x ** 2)

        half = torch.tensor(dim / 2.0, device=self.device, dtype=self.dtype)
        beta_dist = distributions.Beta(half, half)

        w = torch.zeros(w_shape, device=self.device, dtype=self.dtype)
        pending = torch.ones(w_shape, device=self.device, dtype=torch.bool)

        max_attempts = 100
        for _ in range(max_attempts):
            if not torch.any(pending):
                break

            e_new = beta_dist.sample(w_shape)
            u = torch.rand(w_shape, device=self.device, dtype=self.dtype)

            w_new = (1.0 - (1.0 + b) * e_new) / (1.0 - (1.0 - b) * e_new)

            # x * w_new lies in (-1, 1); it must not be clamped at 0, since
            # negative proposals would then be accepted with the wrong odds.
            log_ratio = kappa * w_new + dim * torch.log1p(-x * w_new) - c
            accept = log_ratio >= torch.log(u)

            w = torch.where(pending & accept, w_new, w)
            pending = pending & ~accept

        return w

    def _householder_rotation(self, x):
        """Reflect ``x`` so that the first basis vector is mapped onto ``loc``."""
        u = torch.nn.functional.normalize(self.__e1 - self.loc, dim=-1)
        z = x - 2 * (u * x).sum(dim=-1, keepdim=True) * u
        return z

    def log_prob(self, x):
        """Log-density at unit vectors ``x`` (last axis of size m)."""
        return self._log_unnormalized_prob(x) - self._log_normalization()

    def _log_unnormalized_prob(self, x):
        """Return ``scale * <loc, x>``; optionally checks ``x`` is unit norm."""
        if self._validate_unit_norm:
            x_norm = torch.norm(x, dim=-1)
            if not torch.allclose(x_norm, torch.ones_like(x_norm), atol=1e-3):
                raise ValueError("x must be normalized to unit length")

        return self._kappa * torch.sum(self.loc * x, dim=-1)

    def _log_normalization(self):
        """Return ``log Z`` with ``p(x) = exp(scale * <loc, x>) / Z``.

        ``log Z = (m/2) log(2 pi) + log I_{m/2-1}(k) - (m/2 - 1) log k`` where
        ``log I_v(k) = k + log ive(v, k)``.
        """
        m = float(self.__m)
        kappa = self._kappa
        log_c = ((m / 2.0 - 1) * torch.log(kappa) -
                 (m / 2.0) * math.log(2 * math.pi) -
                 (kappa + torch.log(self._ive(m / 2.0 - 1, kappa))))
        return -log_c

    def _bessel_ratio(self):
        """Return ``I_{m/2}(k) / I_{m/2-1}(k)`` with shape ``batch_shape``."""
        m = float(self.__m)
        kappa = self._kappa
        return self._ive(m / 2.0, kappa) / self._ive(m / 2.0 - 1, kappa)

    def entropy(self):
        """Differential entropy ``-k * A(k) + log Z`` with shape ``batch_shape``."""
        return -self._kappa * self._bessel_ratio() + self._log_normalization()

    def mean(self):
        """Mean vector ``loc * A(k)`` (its norm is below 1)."""
        return self.loc * self._bessel_ratio().unsqueeze(-1)

    def mode(self):
        """Mode of the distribution, which is the mean direction ``loc``."""
        return self.loc

    def _ive(self, v, z):
        """Exponentially scaled modified Bessel function of the first kind.

        Delegates to the SciPy-backed, autograd-aware ``ive`` helper defined
        in an earlier cell of this notebook, which is exact for any order and
        stays finite at the large concentrations this model uses.
        """
        return ive(v, z)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as dist
import math

class MixedCurvatureVAE(nn.Module):
    def __init__(
        self, 
        n_gene, 
        n_batch=None, 
        z_dim=2,
        encoder_layer=None, 
        decoder_layer=None, 
        latent_dist="vmf", 
        observation_dist="nb",
        batch_invariant=False
    ):
        super(MixedCurvatureVAE, self).__init__()
        
        if encoder_layer is None:
            encoder_layer = [128, 64, 32]
        if decoder_layer is None:
            decoder_layer = [32, 128]
        
        self.batch_invariant = batch_invariant
        self.n_input_feature = n_gene
        self.z_dim = z_dim
        self.encoder_layer = encoder_layer
        self.decoder_layer = decoder_layer
        self.latent_dist = latent_dist
        self.observation_dist = observation_dist
        
        if self.latent_dist == "vmf":
            self.z_dim += 1
        
        if not isinstance(n_batch, list):
            n_batch = [n_batch]
        
        self.n_batch = n_batch
        
        self._build_encoder()
        self._build_decoder()
    
    def _build_encoder(self):
        layers = []
        
        input_size = self.n_input_feature
        if not self.batch_invariant:
            input_size += sum(self.n_batch)
        
        in_features = input_size
        for units in self.encoder_layer:
            layers.append(nn.Linear(in_features, units))
            layers.append(nn.BatchNorm1d(units))
            layers.append(nn.ELU())
            in_features = units
        
        self.encoder_hidden = nn.Sequential(*layers)
        
        if self.latent_dist == "normal":
            self.z_mu_layer = nn.Linear(self.encoder_layer[-1], self.z_dim)
            self.z_sigma_square_layer = nn.Linear(self.encoder_layer[-1], self.z_dim)
        elif self.latent_dist == "vmf":
            self.z_mu_layer = nn.Linear(self.encoder_layer[-1], self.z_dim)
            self.z_sigma_square_layer = nn.Linear(self.encoder_layer[-1], 1)
        elif self.latent_dist == "wn":
            self.z_mu_layer = nn.Linear(self.encoder_layer[-1], self.z_dim)
            self.z_sigma_square_layer = nn.Linear(self.encoder_layer[-1], self.z_dim)
        else:
            raise NotImplementedError()
    
    def _build_decoder(self):
        layers = []
        
        input_size = self.z_dim + sum(self.n_batch)
        
        in_features = input_size
        for units in self.decoder_layer:
            layers.append(nn.Linear(in_features, units))
            layers.append(nn.BatchNorm1d(units))
            layers.append(nn.ELU())
            in_features = units
        
        self.decoder_hidden = nn.Sequential(*layers)
        
        self.mu_layer = nn.Linear(self.decoder_layer[-1], self.n_input_feature)
        self.sigma_square_layer = nn.Linear(self.decoder_layer[-1], self.n_input_feature)
    
    def multi_one_hot(self, index_tensor, depth_list):
        # Convert multiple batch indices to one-hot encodings
        batch_size = index_tensor.size(0)
        one_hot_tensor = torch.zeros(batch_size, sum(depth_list), device=index_tensor.device)
        
        start_idx = 0
        for col in range(len(depth_list)):
            indices = index_tensor[:, col]
            for i in range(batch_size):
                if indices[i] < depth_list[col]:
                    one_hot_tensor[i, start_idx + indices[i]] = 1
            start_idx += depth_list[col]
        
        return one_hot_tensor
    
    def _encoder(self, x, batch):
        if self.observation_dist == "nb":
            x = torch.log1p(x)
            
            if self.latent_dist == "vmf":
                x = F.normalize(x, p=2, dim=-1)
        
        if not self.batch_invariant:
            x = torch.cat([x, batch], dim=1)
        
        h = self.encoder_hidden(x)
        
        if self.latent_dist == "normal":
            z_mu = self.z_mu_layer(h)
            z_sigma_square = F.softplus(self.z_sigma_square_layer(h))
        elif self.latent_dist == "vmf":
            z_mu = self.z_mu_layer(h)
            z_mu = F.normalize(z_mu, p=2, dim=-1)
            z_sigma_square = F.softplus(self.z_sigma_square_layer(h)) + 1
            z_sigma_square = torch.clamp(z_sigma_square, 1, 10000)
        elif self.latent_dist == "wn":
            z_mu = self.z_mu_layer(h)
            z_mu = self._polar_project(z_mu)
            z_sigma_square = F.softplus(self.z_sigma_square_layer(h))
        else:
            raise NotImplementedError
        
        return z_mu, z_sigma_square
    
    def _decoder(self, z, batch):
        z = torch.cat([z, batch], dim=1)
        
        h = self.decoder_hidden(z)
        
        if self.observation_dist == "nb":
            mu = F.softmax(self.mu_layer(h), dim=1)
            mu = mu * self.library_size
            
            sigma_square = F.softplus(self.sigma_square_layer(h))
            sigma_square = torch.mean(sigma_square, dim=0)
        else:
            mu = self.mu_layer(h)
            sigma_square = F.softplus(self.sigma_square_layer(h))
        
        sigma_square = torch.clamp(sigma_square, self.EPS, self.MAX_SIGMA_SQUARE)
        
        return mu, sigma_square
    
    def _clip_min_value(self, x, eps=1e-6):
        return F.relu(x - eps) + eps
    
    def _polar_project(self, x):
        x_norm = torch.sum(torch.square(x), dim=1, keepdim=True)
        x_norm = torch.sqrt(self._clip_min_value(x_norm))
        
        x_unit = x / x_norm
        x_norm = torch.clamp(x_norm, 0, 32)
        
        z = torch.cat([
            torch.cosh(x_norm), 
            torch.sinh(x_norm) * x_unit
        ], dim=1)
        
        return z
    
    def _depth_regularizer(self, x, batch):
        with torch.no_grad():
            rate = x * 0.2
            samples = torch.poisson(rate)
        
        x_perturbed = F.relu(x - samples)
        z_mu_hat, _ = self._encoder(x_perturbed, batch)
        
        mean_diff = torch.sum(torch.pow(self.z_mu - z_mu_hat, 2), dim=1)
        loss = torch.mean(mean_diff)
        
        return loss
    
    def log_likelihood_nb(self, x, mu, sigma_square, eps=1e-10):
        log_theta_mu_eps = torch.log(sigma_square + mu + eps)
        res = (
            torch.lgamma(x + sigma_square) - torch.lgamma(x + 1) - torch.lgamma(sigma_square) +
            sigma_square * (torch.log(sigma_square) - log_theta_mu_eps) +
            x * (torch.log(mu) - log_theta_mu_eps)
        )
        return res
    
    def log_likelihood_student(self, x, mu, sigma_square, df=5.0):
        df_halved = df / 2
        return (
            torch.lgamma(df_halved + 0.5) - torch.lgamma(df_halved) -
            0.5 * torch.log(math.pi * df * sigma_square) -
            (df_halved + 0.5) * torch.log(1 + (x - mu)**2 / (df * sigma_square))
        )
    
    def forward(self, x, batch_id):
        if len(self.n_batch) > 1:
            batch = self.multi_one_hot(batch_id, self.n_batch)
        else:
            batch = F.one_hot(batch_id, self.n_batch[0]).float()
        
        self.library_size = torch.sum(x, dim=1, keepdim=True)
        
        self.z_mu, self.z_sigma_square = self._encoder(x, batch)
        
        # Sample latent variable
        if self.latent_dist == "normal":
            self.q_z = dist.Normal(self.z_mu, torch.sqrt(self.z_sigma_square))
            self.z = self.q_z.rsample()
            self.p_z = dist.Normal(torch.zeros_like(self.z), torch.ones_like(self.z))
            kl = dist.kl_divergence(self.q_z, self.p_z).sum(dim=1)
            self.kl = torch.mean(kl)
        elif self.latent_dist == 'vmf':
            self.q_z = VonMisesFisher(self.z_mu, self.z_sigma_square)
            self.z = self.q_z.sample()
            self.p_z = HypersphericalUniform(self.z_dim - 1, dtype=x.dtype)
            kl = self.q_z.kl_divergence(self.p_z)
            self.kl = torch.mean(kl)
        elif self.latent_dist == 'wn':
            self.q_z = HyperbolicWrappedNorm(self.z_mu, self.z_sigma_square)
            self.z = self.q_z.sample()
            tmp = self._polar_project(torch.zeros_like(self.z_sigma_square))
            self.p_z = HyperbolicWrappedNorm(tmp, torch.ones_like(self.z_sigma_square))
            kl = self.q_z.log_prob(self.z) - self.p_z.log_prob(self.z)
            self.kl = torch.mean(kl)
        else:
            raise NotImplementedError
        
        # Decoder
        self.mu, self.sigma_square = self._decoder(self.z, batch)
        
        # Depth regularization
        self.depth_loss = self._depth_regularizer(x, batch)
        
        # ELBO calculation
        if