In [1]:
#convert

# babilim.model.layers.batch_normalization

> Apply batch normalization to a tensor.

In [2]:
#export
from babilim.core.annotations import RunOnlyOnce
from babilim.core.module_native import ModuleNative

In [3]:
#export
class BatchNormalization(ModuleNative):
    def __init__(self):
        """
        A batch normalization layer.

        The framework-specific layer is created lazily on the first call, because the number of
        features (and, for PyTorch, the 1d/2d/3d variant) is derived from the first input's shape.
        """
        super().__init__()

    @RunOnlyOnce
    def _build_pytorch(self, features):
        """
        Create the PyTorch batch norm layer matching the rank of the first input.

        The channel/feature count is read from axis 1 (channels-first). Rank 2 and 3 inputs use
        BatchNorm1d, rank 4 uses BatchNorm2d and rank 5 uses BatchNorm3d.

        :param features: The first input; its rank selects the layer and shape[1] its num_features.
        :raises RuntimeError: If the input rank is not 2, 3, 4 or 5.
        """
        import torch
        from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d

        # Input rank -> matching torch batch norm variant.
        batch_norm_by_rank = {2: BatchNorm1d, 3: BatchNorm1d, 4: BatchNorm2d, 5: BatchNorm3d}
        rank = len(features.shape)
        if rank not in batch_norm_by_rank:
            raise RuntimeError(
                "Batch norm is only available for inputs of shape [B,C], [B,C,L], [B,C,H,W] or "
                "[B,C,D,H,W], got shape {}.".format(tuple(features.shape))
            )
        self.bn = batch_norm_by_rank[rank](num_features=features.shape[1])

        # Put the layer's parameters and running statistics on the input's device. Otherwise a CPU
        # input on a CUDA-capable machine would fail with a device mismatch in the forward pass.
        # If the input exposes no device, keep the previous heuristic (CUDA when available).
        device = getattr(features, "device", None)
        if device is not None:
            self.bn = self.bn.to(device)
        elif torch.cuda.is_available():
            self.bn = self.bn.to(torch.device("cuda"))

    def _call_pytorch(self, features):
        """
        Apply the PyTorch batch norm layer created in `_build_pytorch`.

        :param features: The input tensor (same rank and channel count as the first input).
        :return: The normalized tensor.
        """
        return self.bn(features)

    @RunOnlyOnce
    def _build_tf(self, features):
        """
        Create the Keras batch norm layer with its default settings.

        NOTE(review): Keras BatchNormalization normalizes over its default axis (-1, channels-last),
        while the PyTorch path normalizes over axis 1. Confirm this matches the tensor layout
        babilim uses for the TF backend.
        """
        from tensorflow.keras.layers import BatchNormalization as _BN
        self.bn = _BN()

    def _call_tf(self, features):
        """
        Apply the Keras batch norm layer created in `_build_tf`.

        :param features: The input tensor.
        :return: The normalized tensor.
        """
        return self.bn(features)

In [4]:
from babilim.core.tensor import Tensor
import numpy as np

# Demo: normalize a small [B, C] = [3, 4] batch; each column is normalized independently.
batch_norm = BatchNormalization()
tensor = Tensor(data=np.array([[10,3,-4,2], [5, 5, 4, -2], [1,-7,2,0]], dtype=np.float32), trainable=False)

print(tensor.shape)
print(tensor)
result = batch_norm(tensor)
# Show the output's shape (not the input's again); batch norm should preserve the input shape.
print(result.shape)
print(result)

(3, 4)
tensor([[10.,  3., -4.,  2.],
        [ 5.,  5.,  4., -2.],
        [ 1., -7.,  2.,  0.]], device='cuda:0')
(3, 4)
tensor([[ 1.2675,  0.5080, -1.3728,  1.2247],
        [-0.0905,  0.8890,  0.9806, -1.2247],
        [-1.1770, -1.3970,  0.3922,  0.0000]], device='cuda:0',
       grad_fn=<CudnnBatchNormBackward>)
