In [None]:
import numpy as np
import torch
import torch.nn as nn

In [None]:
# Default epsilon used as the `eps` argument of the normalization layers
# defined in this notebook (keeps the normalizing denominator away from zero).
EPS = 1e-8

In [None]:
"""
    Based on https://github.com/naplab/Conv-TasNet/blob/master/utility/models.py#L8-L45
"""
class OriginalCumlativeLayerNorm(nn.Module):
    """Cumulative layer normalization for inputs of size (Batch, Freq, Time).

    At every time step ``t`` the input is normalized with the mean and the
    variance computed over all channels and all time steps ``0..t``, then
    multiplied by ``gain`` and shifted by ``bias``.

    Args:
        dimension: Number of channels (size of dim 1 of the input).
        eps: Constant added to the cumulative variance before the square root.
        trainable: If True, ``gain`` and ``bias`` are learnable
            ``nn.Parameter`` objects; otherwise they are fixed tensors of
            ones and zeros.
    """

    def __init__(self, dimension, eps=EPS, trainable=True):
        super(OriginalCumlativeLayerNorm, self).__init__()

        self.eps = eps
        if trainable:
            self.gain = nn.Parameter(torch.ones(1, dimension, 1))
            self.bias = nn.Parameter(torch.zeros(1, dimension, 1))
        else:
            # Plain tensors replace the deprecated
            # `Variable(..., requires_grad=False)` wrapper: `Variable` is not
            # among the notebook's imports, so this branch used to raise
            # NameError. Freshly created tensors do not require grad.
            self.gain = torch.ones(1, dimension, 1)
            self.bias = torch.zeros(1, dimension, 1)

    def forward(self, input):
        """Normalize ``input`` cumulatively along the time axis.

        Args:
            input: Tensor of size (Batch, Freq, Time).

        Returns:
            Tensor with the same shape as ``input``.
        """
        channel = input.size(1)
        time_step = input.size(2)

        # Per-step sums over the channel axis, then running totals over time.
        step_sum = input.sum(1)  # B, T
        step_pow_sum = input.pow(2).sum(1)  # B, T
        cum_sum = torch.cumsum(step_sum, dim=1)  # B, T
        cum_pow_sum = torch.cumsum(step_pow_sum, dim=1)  # B, T

        # Number of entries accumulated up to each step: [C, 2C, ..., T*C].
        # Built directly on the input's device instead of going through numpy.
        entry_cnt = torch.arange(
            channel, channel * (time_step + 1), channel, device=input.device
        )
        entry_cnt = entry_cnt.to(input.dtype)
        entry_cnt = entry_cnt.view(1, -1).expand_as(cum_sum)

        cum_mean = cum_sum / entry_cnt  # B, T
        cum_var = (cum_pow_sum - 2*cum_mean*cum_sum) / entry_cnt + cum_mean.pow(2)  # B, T
        cum_std = (cum_var + self.eps).sqrt()  # B, T

        # Re-insert the channel axis so the statistics broadcast over it.
        cum_mean = cum_mean.unsqueeze(1)
        cum_std = cum_std.unsqueeze(1)

        x = (input - cum_mean.expand_as(input)) / cum_std.expand_as(input)

        # `.type(x.type())` casts gain/bias to the tensor type of the data.
        return x * self.gain.expand_as(x).type(x.type()) + self.bias.expand_as(x).type(x.type())

In [None]:
"""
    Based on https://github.com/tky823/DNN-based_source_separation/blob/main/src/norm.py#L42-L103
"""
class MyCumulativeLayerNorm(nn.Module):
    """Cumulative layer normalization with learnable scale and shift.

    At every time step ``t`` the input is normalized with the mean and the
    variance computed over all channels and all time steps ``0..t``, then
    multiplied by ``gamma`` and shifted by ``beta``.

    Args:
        num_features: Number of channels ``C`` (size of dim 1 of the input).
        eps: Constant added to the cumulative standard deviation.
    """

    def __init__(self, num_features, eps=EPS):
        super().__init__()

        self.num_features = num_features
        self.eps = eps

        # Allocated uninitialized; values are set by `_reset_parameters`.
        self.gamma = nn.Parameter(torch.empty(1, num_features, 1))
        self.beta = nn.Parameter(torch.empty(1, num_features, 1))

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialize ``gamma`` to ones and ``beta`` to zeros."""
        self.gamma.data.fill_(1)
        self.beta.data.zero_()

    def forward(self, input):
        """
        Args:
            input (batch_size, C, T) or (batch_size, C, S, chunk_size):
        Returns:
            output (batch_size, C, T) or (batch_size, C, S, chunk_size): same shape as the input
        Raises:
            ValueError: if the input is neither 3D nor 4D.
        """
        eps = self.eps

        n_dim = input.dim()

        if n_dim == 3:
            batch_size, C, T = input.size()
        elif n_dim == 4:
            batch_size, C, S, chunk_size = input.size()
            T = S * chunk_size
            # reshape (rather than view) also accepts non-contiguous inputs
            input = input.reshape(batch_size, C, T)
        else:
            raise ValueError("Only support 3D or 4D input, but given {}D".format(input.dim()))

        step_sum = input.sum(dim=1) # -> (batch_size, T)
        step_pow_sum = (input**2).sum(dim=1) # -> (batch_size, T)
        cum_sum = torch.cumsum(step_sum, dim=1) # -> (batch_size, T)
        cum_squared_sum = torch.cumsum(step_pow_sum, dim=1) # -> (batch_size, T)

        # Entry counts [C, 2*C, ..., T*C], created on the input's device so
        # that inputs living on an accelerator do not hit a device mismatch.
        cum_num = torch.arange(C, C*(T+1), C, dtype=torch.float, device=input.device) # -> (T, )
        cum_mean = cum_sum / cum_num # (batch_size, T)
        cum_squared_mean = cum_squared_sum / cum_num
        # E[x^2] - E[x]^2 can come out slightly negative through floating-point
        # cancellation; clamp at zero so the square root never yields NaN.
        cum_var = torch.clamp(cum_squared_mean - cum_mean**2, min=0)

        # Re-insert the channel axis so the statistics broadcast over it.
        cum_mean = cum_mean.unsqueeze(dim=1)
        cum_var = cum_var.unsqueeze(dim=1)

        output = (input - cum_mean) / (torch.sqrt(cum_var) + eps) * self.gamma + self.beta

        if n_dim == 4:
            output = output.view(batch_size, C, S, chunk_size)

        return output

    def __repr__(self):
        """Return ``ClassName(num_features, eps=eps)``."""
        return '{}({}, eps={})'.format(self.__class__.__name__, self.num_features, self.eps)

In [None]:
# Seed PyTorch's RNG so the random test input drawn later is reproducible.
torch.manual_seed(111)
# Test dimensions: batch size, number of channels, number of time steps.
B, C, T = 4, 3, 10

In [None]:
# Instantiate both implementations for C channels so they can be compared.
original_layer_norm = OriginalCumlativeLayerNorm(C)
my_layer_norm = MyCumulativeLayerNorm(C)

In [None]:
# Feed the same random (B, C, T) tensor through both layers.
# NOTE(review): `input` shadows the Python builtin of the same name.
input = torch.randn(B, C, T)
original_output = original_layer_norm(input)
my_output = my_layer_norm(input)

In [None]:
# Check that both implementations agree within torch.allclose's default
# tolerances; as the cell's last expression the boolean is displayed.
torch.allclose(original_output, my_output)