In [1]:
#convert

# babilim.model.layers.reshape

> Reshape a tensor.

In [2]:
#export
from babilim.core.annotations import RunOnlyOnce
from babilim.core.module_native import ModuleNative

In [3]:
#export
class Reshape(ModuleNative):
    def __init__(self, output_shape):
        """
        Reshape a tensor.

        A tensor of shape (B, ?) where B is the batch size gets reshaped into (B, output_shape[0], output_shape[1], ...).
        The batch size is kept and all other dimensions are determined by output_shape.
        For example, with output_shape=(8, 24) an input of shape (10, 8, 8, 3) becomes (10, 8, 24).

        :param output_shape: The per-sample target shape. The tensor's shape after reshaping is (batch_size,) + output_shape,
            meaning the batch size is kept automatically. Like any reshape, the number of elements per sample must stay the same.
        """
        super().__init__()
        self.output_shape = output_shape

    @RunOnlyOnce
    def _build_pytorch(self, features):
        """
        Prepare the PyTorch backend by normalizing output_shape to a list.

        :param features: The first input seen by the layer (unused here).
        """
        self.output_shape = list(self.output_shape)

    def _call_pytorch(self, features):
        """
        Reshape a native PyTorch tensor to (batch_size,) + output_shape.

        :param features: The input tensor; its first dimension is treated as the batch size.
        :return: The reshaped tensor.
        """
        # Prepend the batch dimension; unpacking works whether output_shape is a list or a tuple.
        shape = [features.shape[0], *self.output_shape]
        # `reshape` instead of `view`: `view` raises on non-contiguous inputs (e.g. after permute/transpose),
        # while `reshape` returns a view when possible and only copies when it has to.
        return features.reshape(shape)

    @RunOnlyOnce
    def _build_tf(self, features):
        """
        Prepare the TensorFlow backend by creating a Keras Reshape layer for output_shape.

        Keras' Reshape target shape excludes the batch dimension, which matches this layer's output_shape semantics.

        :param features: The first input seen by the layer (unused here).
        """
        from tensorflow.keras.layers import Reshape as _Reshape
        self.reshape = _Reshape(self.output_shape)

    def _call_tf(self, features):
        """
        Reshape a native TensorFlow tensor with the Keras layer created in _build_tf.

        :param features: The input tensor.
        :return: The reshaped tensor.
        """
        return self.reshape(features)

In [4]:
from babilim.core.tensor import Tensor
import numpy as np

# Demo: flatten the trailing (8, 3) dimensions of a (10, 8, 8, 3) batch into 24 while keeping batch and first axis.
reshape = Reshape(output_shape=(8, 24))
tensor = Tensor(data=np.zeros((10,8,8,3), dtype=np.float32), trainable=False)

print(tensor.shape)
result = reshape(tensor)
print(result.shape)

# Pin the demonstrated behaviour so a regression fails the notebook run instead of only changing the printout.
assert tuple(result.shape) == (10, 8, 24), f"unexpected output shape {tuple(result.shape)}"

(10, 8, 8, 3)
(10, 8, 24)
