In [1]:
#convert

# babilim.model.layers.tensor_combiners

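> Layers that combine a list of tensors into a single tensor, either by stacking them along a new axis or by concatenating them along an existing axis.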

In [2]:
#export
from babilim.core.annotations import RunOnlyOnce
from babilim.core.module_native import ModuleNative

In [3]:
#export
class Stack(ModuleNative):
    def __init__(self, axis):
        """
        Stack tensors along a new axis.

        Creates a callable object with the following signature:
        * **tensor_list**: (List[Tensor]) The tensors that should be stacked. A list of length S containing Tensors which must all have the same shape.
        * **return**: A tensor of shape [..., S, ...] where the position at which S is in the shape is equal to the axis. The result has one more dimension than each input tensor.

        Parameters of the constructor.
        :param axis: (int) The axis along which the stacking happens (the index of the newly inserted dimension in the output).
        """
        super().__init__()
        self.axis = axis

    @RunOnlyOnce
    def _build_pytorch(self, tensor_list):
        # Stacking has no parameters, so there is nothing to build.
        pass

    def _call_pytorch(self, tensor_list):
        # Backend imported lazily, presumably so only the active backend needs to be installed.
        import torch
        return torch.stack(tensor_list, dim=self.axis)

    @RunOnlyOnce
    def _build_tf(self, tensor_list):
        # Stacking has no parameters, so there is nothing to build.
        pass

    def _call_tf(self, tensor_list):
        # Backend imported lazily, presumably so only the active backend needs to be installed.
        import tensorflow as tf
        return tf.stack(tensor_list, axis=self.axis)

In [4]:
from babilim.core.tensor import Tensor
import numpy as np

stack = Stack(axis=1)
tensor1 = Tensor(data=np.zeros((10,8,8,3)), trainable=False)
tensor2 = Tensor(data=np.zeros((10,8,8,3)), trainable=False)

print(tensor1.shape)
print(tensor2.shape)
result = stack([tensor1, tensor2])
print(result.shape)
# Stacking S=2 tensors on axis=1 inserts a new dimension of size 2 at position 1.
assert tuple(result.shape) == (10, 2, 8, 8, 3)

(10, 8, 8, 3)
(10, 8, 8, 3)
(10, 2, 8, 8, 3)


In [5]:
#export
class Concat(ModuleNative):
    def __init__(self, axis):
        """
        Concatenate tensors along an existing axis.

        Creates a callable object with the following signature:
        * **tensor_list**: (List[Tensor]) The tensors that should be concatenated. A list of length S containing Tensors which must have the same shape in every dimension except `axis`.
        * **return**: A tensor with the same rank as the inputs whose size along `axis` is the sum of the input sizes along `axis` (i.e. S * inp_tensor.shape[axis] when all inputs have the same shape). All other dimensions are unchanged; no new dimension is added.

        Parameters of the constructor.
        :param axis: (int) The axis along which the concatenation happens.
        """
        super().__init__()
        self.axis = axis

    @RunOnlyOnce
    def _build_pytorch(self, tensor_list):
        # Concatenation has no parameters, so there is nothing to build.
        pass

    def _call_pytorch(self, tensor_list):
        # Backend imported lazily, presumably so only the active backend needs to be installed.
        import torch
        return torch.cat(tensor_list, dim=self.axis)

    @RunOnlyOnce
    def _build_tf(self, tensor_list):
        # Concatenation has no parameters, so there is nothing to build.
        pass

    def _call_tf(self, tensor_list):
        # Backend imported lazily, presumably so only the active backend needs to be installed.
        import tensorflow as tf
        return tf.concat(tensor_list, axis=self.axis)

In [6]:
from babilim.core.tensor import Tensor
import numpy as np

concat = Concat(axis=1)
tensor1 = Tensor(data=np.zeros((10,8,8,3)), trainable=False)
tensor2 = Tensor(data=np.zeros((10,8,8,3)), trainable=False)

print(tensor1.shape)
print(tensor2.shape)
result = concat([tensor1, tensor2])
print(result.shape)
# Concatenating on axis=1 sums the sizes along that axis (8 + 8 = 16); the rank is unchanged.
assert tuple(result.shape) == (10, 16, 8, 3)

(10, 8, 8, 3)
(10, 8, 8, 3)
(10, 16, 8, 3)
