In [37]:
import torch
import torch.nn as nn
from layers import Lorentz
from geoopt import ManifoldParameter

class LorentzBatchNormTheirs(nn.Module):
    """ Lorentz Batch Normalization with Centroid and Fréchet variance

    Reference implementation, kept structurally close to the original
    (tangent vectors are moved through the origin with ``transp0back`` /
    ``transp0``). In training mode the batch is centered at its centroid, the
    centered tangent vectors are rescaled by ``gamma / (var + eps)`` and then
    re-centered at the learnable point ``beta``. Eval mode does the same with
    the running mean/variance instead of the batch statistics.

    Args:
        manifold: Lorentz manifold providing the geometric primitives.
        num_features: Size of the last axis of the input points.
    """
    def __init__(self, manifold: Lorentz, num_features: int):
        super(LorentzBatchNormTheirs, self).__init__()
        self.manifold = manifold

        # Learnable mean (a manifold point, starts at the origin) and learnable scale.
        self.beta = ManifoldParameter(self.manifold.origin(num_features), manifold=self.manifold)
        self.gamma = torch.nn.Parameter(torch.ones((1,)))
        self.eps = 1e-5

        # Running statistics. The running mean is stored as a tangent vector at
        # the origin (num_features - 1 values, read back through expmap0).
        self.register_buffer('running_mean', torch.zeros(num_features - 1))
        self.register_buffer('running_var', torch.ones((1,)))

    def forward(self, x, momentum=0.1):
        """Normalize points on the Lorentz manifold.

        Args:
            x: Manifold points of shape ``(B, C)`` or ``(B, L, C)``.
            momentum: Weight of the current batch in the running-statistics
                update (training mode only).

        Returns:
            Normalized points with the same shape as ``x``.
        """
        assert (len(x.shape)==2) or (len(x.shape)==3), "Wrong input shape in Lorentz batch normalization."
        beta = self.beta

        if self.training:
            # Compute batch mean (3D input is reduced with a second centroid)
            mean = self.manifold.centroid(x)
            if len(x.shape) == 3:
                mean = self.manifold.centroid(mean)
            # Transport batch to origin (center batch)
            x_T = self.manifold.logmap(mean, x)
            x_T = self.manifold.transp0back(mean, x_T)

            # Compute Fréchet variance: mean tangent-vector norm over every
            # batch element (equivalent to the per-rank dim=0 / dim=(0, 1) means)
            var = torch.norm(x_T, dim=-1).mean()

            # Rescale batch
            x_T = x_T*(self.gamma/(var+self.eps))

            # Transport batch to learned mean
            x_T = self.manifold.transp0(beta, x_T)
            output = self.manifold.expmap(beta, x_T)

            # Save running parameters. The running mean is a manifold point, so
            # it is blended as a weighted centroid of (old running mean, batch
            # mean) instead of a Euclidean EMA, then stored back as a tangent
            # vector at the origin.
            with torch.no_grad():
                running_mean = self.manifold.expmap0(self.running_mean)
                means = torch.concat((running_mean.unsqueeze(0), mean.detach().unsqueeze(0)), dim=0)
                # Match the weights' dtype to the means; torch.tensor would
                # otherwise use the default dtype regardless of the input.
                weights = torch.tensor(
                    ((1 - momentum), momentum), device=means.device, dtype=means.dtype
                )
                self.running_mean.copy_(
                    self.manifold.logmap0(self.manifold.centroid(means, weights=weights))
                )
                self.running_var.copy_((1 - momentum)*self.running_var + momentum*var.detach())

        else:
            # Transport batch to origin (center batch) using the running mean
            running_mean = self.manifold.expmap0(self.running_mean)
            x_T = self.manifold.logmap(running_mean, x)
            x_T = self.manifold.transp0back(running_mean, x_T)

            # Rescale batch with the running variance
            x_T = x_T*(self.gamma/(self.running_var+self.eps))

            # Transport batch to learned mean
            x_T = self.manifold.transp0(beta, x_T)
            output = self.manifold.expmap(beta, x_T)

        return output

class LorentzBatchNorm1d(LorentzBatchNormTheirs):
    """ 1D Lorentz Batch Normalization with Centroid and Fréchet variance

    Thin alias of ``LorentzBatchNormTheirs``: accepts the same ``(B, C)`` or
    ``(B, L, C)`` inputs and forwards them unchanged to the parent.
    """
    def __init__(self, manifold: Lorentz, num_features: int):
        super().__init__(manifold, num_features)

    def forward(self, x, momentum=0.1):
        # Pure delegation; kept explicit so the 1D entry point is self-documenting.
        return super().forward(x, momentum)

class LorentzBatchNorm2dTheirs(LorentzBatchNormTheirs):
    """ 2D Lorentz Batch Normalization with Centroid and Fréchet variance

    Operates on channel-first feature maps: every spatial location of a
    ``(bs, C, H, W)`` tensor is treated as one manifold point whose
    coordinates run along the channel axis.
    """
    def __init__(self, manifold: Lorentz, num_features: int):
        super(LorentzBatchNorm2dTheirs, self).__init__(manifold, num_features)

    def forward(self, x, momentum=0.1):
        """ x has to be in channel-first representation -> Shape = bs x C x H x W

        Pixels are flattened to ``(bs, H*W, C)`` for the parent class, and the
        result is restored to ``(bs, C, H, W)``.
        """
        bs, c, h, w = x.shape
        # (bs, C, H, W) -> (bs, H, W, C) -> (bs, H*W, C): one manifold point per pixel
        x = x.permute(0, 2, 3, 1).reshape(bs, -1, c)
        x = super(LorentzBatchNorm2dTheirs, self).forward(x, momentum)
        # Undo the flattening and move channels back to dim 1
        x = x.reshape(bs, h, w, c).permute(0, 3, 1, 2)

        return x

class LorentzBatchNormOurs(nn.Module):
    """
    Lorentz Batch Normalization following Bdeir et al.
    Simplified to use manifold primitives.

    Same algorithm as ``LorentzBatchNormTheirs``, but tangent vectors are moved
    between the centering point, the origin and ``beta`` with explicit
    ``parallel_transport`` calls.

    Args:
        num_features: Size of the last axis of the input points.
        manifold: Lorentz manifold to use; defaults to ``Lorentz(k=1.0)``.
        momentum: Default weight of the current batch in the running-statistics
            update; can be overridden per call in ``forward``.
        eps: Added to the variance before dividing, to avoid division by zero.
    """

    def __init__(
        self,
        num_features: int,
        manifold: Lorentz = None,
        momentum: float = 0.1,
        eps: float = 1e-5,
    ):
        super().__init__()
        self.manifold = manifold if manifold is not None else Lorentz(k=1.0)
        self.num_features = num_features
        self.momentum = momentum
        self.eps = eps

        # Learnable mean: a point on the manifold, initialised at the origin.
        self.beta = ManifoldParameter(self.manifold.origin(num_features), manifold=self.manifold)
        # Learnable scale applied to the normalized tangent vectors.
        self.gamma = nn.Parameter(torch.ones((1,)))

        # Running statistics. The running mean is a manifold point stored as its
        # tangent vector at the origin (num_features - 1 values, read back with
        # expmap0), so it can live in a plain buffer.
        self.register_buffer('running_mean', torch.zeros(num_features - 1))
        self.register_buffer('running_var', torch.ones((1,)))

    def forward(self, x: torch.Tensor, momentum=None) -> torch.Tensor:
        """Normalize points on the Lorentz manifold.

        Args:
            x: Manifold points of shape ``(B, C)`` or ``(B, L, C)``.
            momentum: Weight of the current batch in the running-statistics
                update (training mode only). ``None`` uses ``self.momentum``.

        Returns:
            Normalized points with the same shape as ``x``.

        Raises:
            ValueError: if ``x`` is neither 2D nor 3D.
        """
        if x.dim() not in (2, 3):
            raise ValueError(
                f"Expected a 2D (B, C) or 3D (B, L, C) input, got shape {tuple(x.shape)}."
            )
        # Build the origin on x's device/dtype so parallel_transport never mixes devices.
        origin = self.manifold.origin(x.shape[-1]).to(x)

        if self.training:
            if momentum is None:
                momentum = self.momentum
            # Batch mean; for 3D input a second centroid reduces it to a single point.
            mean = self.manifold.centroid(x)
            if x.dim() == 3:
                mean = self.manifold.centroid(mean)

            x_T = self._to_origin(x, mean, origin)
            # Fréchet variance: mean norm of the centered tangent vectors over
            # every element of the batch.
            var = torch.norm(x_T, dim=-1).mean()
            output = self._rescale_to_beta(x_T, var, origin)

            self._update_running_stats(mean.detach(), var.detach(), momentum)
        else:
            running_mean = self.manifold.expmap0(self.running_mean)
            x_T = self._to_origin(x, running_mean, origin)
            output = self._rescale_to_beta(x_T, self.running_var, origin)

        return output

    def _to_origin(self, x, center, origin):
        """Log-map ``x`` at ``center`` and parallel-transport the vectors to ``origin``."""
        x_T = self.manifold.logmap(center, x)
        return self.manifold.parallel_transport(center, x_T, origin)

    def _rescale_to_beta(self, x_T, var, origin):
        """Scale origin tangent vectors by ``gamma / (var + eps)`` and exp-map them around ``beta``."""
        x_T = x_T * (self.gamma / (var + self.eps))
        x_T = self.manifold.parallel_transport(origin, x_T, self.beta)
        return self.manifold.expmap(self.beta, x_T)

    @torch.no_grad()
    def _update_running_stats(self, mean, var, momentum):
        """Blend the batch statistics into the running buffers.

        A plain Euclidean EMA of hyperboloid points would leave the manifold,
        so the new running mean is the weighted centroid of the old running
        mean and the batch mean (weights ``1 - momentum`` and ``momentum``),
        stored back as a tangent vector at the origin via ``logmap0``. The
        variance is a scalar and uses the usual EMA.
        """
        running_mean = self.manifold.expmap0(self.running_mean)
        means = torch.stack((running_mean, mean), dim=0)
        weights = torch.tensor(
            (1 - momentum, momentum), device=means.device, dtype=means.dtype
        )
        self.running_mean.copy_(
            self.manifold.logmap0(self.manifold.centroid(means, weights=weights))
        )
        self.running_var.copy_((1 - momentum) * self.running_var + momentum * var)


class LorentzBatchNorm2dOurs(LorentzBatchNormOurs):
    """ 2D Lorentz Batch Normalization with Centroid and Fréchet variance

    Operates on channel-first feature maps: every spatial location of a
    ``(bs, C, H, W)`` tensor is treated as one manifold point whose
    coordinates run along the channel axis.
    """
    def __init__(self, manifold: Lorentz, num_features: int, momentum: float = 0.1, eps: float = 1e-5):
        super().__init__(num_features, manifold, momentum=momentum, eps=eps)

    def forward(self, x, momentum=None):
        """ x has to be in channel-first representation -> Shape = bs x C x H x W

        Args:
            x: Tensor of shape ``(bs, C, H, W)``.
            momentum: Running-statistics momentum for this call; ``None`` uses
                ``self.momentum``.

        Returns:
            Normalized tensor of shape ``(bs, C, H, W)``.
        """
        bs, c, h, w = x.shape
        # (bs, C, H, W) -> (bs, H, W, C) -> (bs, H*W, C): one manifold point per pixel
        x = x.permute(0, 2, 3, 1).reshape(bs, -1, c)
        x = super().forward(x, momentum)
        # Undo the flattening and move channels back to dim 1
        return x.reshape(bs, h, w, c).permute(0, 3, 1, 2)


In [38]:
# Seed before drawing random inputs so every output below is reproducible on Restart & Run All.
torch.manual_seed(0)
man = Lorentz(k=1.0)
our_bnorm = LorentzBatchNorm2dOurs(num_features=4, manifold=man)
inp = torch.randn(4, 3, 2, 2)
# Presumably lifts the 3 space channels onto the manifold along dim 1, giving
# 4 channels to match num_features=4 — TODO confirm against Lorentz docs.
inp_on_man = man.projection_space_orthogonal(inp, manifold_dim=1)

In [39]:
# Training-mode forward of a fresh module: normalizes with batch statistics
# and updates running_mean / running_var.
my_output = our_bnorm(inp_on_man)


In [40]:
# Running mean after one training step, stored as a tangent vector at the origin.
our_bnorm.running_mean

tensor([-0.0031,  0.0160, -0.0138])

In [41]:
their_bnorm = LorentzBatchNorm2dTheirs(manifold=man, num_features=4)
their_train_output = their_bnorm(inp_on_man)
# Same input and fresh parameters (beta at the origin, gamma = 1): the two
# implementations should agree up to the transport implementations' rounding.
torch.testing.assert_close(their_train_output.detach(), my_output.detach(), rtol=1e-4, atol=1e-5)
their_train_output[0, :, 0, 0]

tensor([ 1.2451, -0.1815,  0.2542, -0.6729], grad_fn=<SelectBackward0>)

In [42]:
# Both modules saw the same batch once, so their running statistics should match.
torch.testing.assert_close(their_bnorm.running_mean, our_bnorm.running_mean, rtol=1e-4, atol=1e-5)
torch.testing.assert_close(their_bnorm.running_var, our_bnorm.running_var, rtol=1e-4, atol=1e-5)
their_bnorm.running_mean

tensor([-0.0031,  0.0160, -0.0138])

In [43]:
# Switch ours to eval mode and draw a fresh batch: eval uses the running
# statistics and does not update them.
# NOTE: this overwrites `inp` / `inp_on_man` from the setup cell.
our_bnorm.eval()
inp = torch.randn(4, 3, 2, 2)
inp_on_man = man.projection_space_orthogonal(inp, manifold_dim=1)

In [44]:
# Eval-mode output. NOTE: reuses the name `my_output`, replacing the
# training-mode output computed earlier.
my_output = our_bnorm(inp_on_man)
my_output

tensor([[[[ 1.1451e+00,  2.0506e+00],
          [ 1.7814e+00,  2.1844e+00]],

         [[-1.3016e-01,  9.7539e-02],
          [-9.9245e-01,  1.8522e-01]],

         [[ 3.5514e-02,  1.7321e+00],
          [-1.5268e-01, -1.1098e+00]],

         [[ 5.4127e-01, -4.4209e-01],
          [-1.0795e+00, -1.5830e+00]]],


        [[[ 1.7275e+00,  1.4228e+00],
          [ 1.7409e+00,  1.0843e+00]],

         [[ 1.3762e+00,  6.4639e-01],
          [-1.1801e-01, -4.0858e-02]],

         [[ 2.8431e-01, -3.4420e-01],
          [-9.3581e-01, -2.3233e-01]],

         [[ 9.8447e-02,  6.9852e-01],
          [-1.0681e+00, -3.4638e-01]]],


        [[[ 1.1432e+00,  3.1658e+00],
          [ 1.6041e+00,  1.9990e+00]],

         [[ 3.2754e-01,  2.5240e+00],
          [-1.0612e+00,  1.6859e+00]],

         [[ 1.2658e-01,  1.0223e+00],
          [ 2.2129e-03, -7.6190e-02]],

         [[-4.2844e-01,  1.2676e+00],
          [-6.6860e-01, -3.8480e-01]]],


        [[[ 1.8585e+00,  2.0012e+00],
          [ 1.7349e+

In [45]:
# Should be unchanged from the post-training value: the eval branch never updates the buffers.
our_bnorm.running_mean

tensor([-0.0031,  0.0160, -0.0138])

In [46]:
their_bnorm.eval()
their_eval_output = their_bnorm(inp_on_man)
# Eval mode uses the (matching) running statistics, so the result should equal ours.
torch.testing.assert_close(their_eval_output.detach(), my_output.detach(), rtol=1e-4, atol=1e-5)
their_eval_output

tensor([[[[ 1.1451e+00,  2.0506e+00],
          [ 1.7814e+00,  2.1844e+00]],

         [[-1.3016e-01,  9.7539e-02],
          [-9.9245e-01,  1.8522e-01]],

         [[ 3.5514e-02,  1.7321e+00],
          [-1.5268e-01, -1.1098e+00]],

         [[ 5.4127e-01, -4.4209e-01],
          [-1.0795e+00, -1.5830e+00]]],


        [[[ 1.7275e+00,  1.4228e+00],
          [ 1.7409e+00,  1.0843e+00]],

         [[ 1.3762e+00,  6.4639e-01],
          [-1.1801e-01, -4.0858e-02]],

         [[ 2.8431e-01, -3.4420e-01],
          [-9.3581e-01, -2.3233e-01]],

         [[ 9.8447e-02,  6.9852e-01],
          [-1.0681e+00, -3.4638e-01]]],


        [[[ 1.1432e+00,  3.1658e+00],
          [ 1.6041e+00,  1.9990e+00]],

         [[ 3.2754e-01,  2.5240e+00],
          [-1.0612e+00,  1.6859e+00]],

         [[ 1.2658e-01,  1.0223e+00],
          [ 2.2129e-03, -7.6190e-02]],

         [[-4.2844e-01,  1.2676e+00],
          [-6.6860e-01, -3.8480e-01]]],


        [[[ 1.8585e+00,  2.0012e+00],
          [ 1.7349e+

In [47]:
# Should also be unchanged by the eval-mode forward.
their_bnorm.running_mean

tensor([-0.0031,  0.0160, -0.0138])