In [1]:
import torch
import torch.nn as nn
import math

from torch.utils.data import DataLoader

In [2]:
class SplitLayer(nn.Module):
    """ 
    An implementation of a neural network layer that given an input X it does the following:
    - Split X into two halves X1, X2
    - Feed X1, X2 into the same LinearLayer yielding outputs Z1, Z2
    - Feed Z1, Z2 into a ReLU layer yielding outputs Y1, Y2
    - Concatenate Y=[Y1, Y2]
    """

    # The number of input batches
    input_batches: int

    # The number of input features.
    # Because we split the input to two halves, this is expected to be even.
    input_features: int

    # Determines the size of each split.
    # It is basically input_features/2
    split_size: int

    # Determines the dimension to split by.
    # It is basically 1 since 1 is the expected features dimension.
    split_dim: int

    def __init__(self, input_batches: int, input_features: int):
        super().__init__()

        assert input_batches > 0, "expected input_batches to be positive"
        assert input_features % 2 == 0, "expected input_features to be even"

        self.input_batches = input_batches
        self.input_features = input_features

        self.split_dim = 1
        self.split_size = self.input_features // 2

        self.hidden = nn.Sequential(
            nn.Linear(self.split_size, self.split_size),
            nn.ReLU())

        self.initialize_weights()

    def initialize_weights(self):
        """
        Weight Initialization using Xavier's method.

        As explained here
        - https://d2l.ai/chapter_multilayer-perceptrons/numerical-stability-and-init.html#xavier-initialization
        
        Xavier's method tackles the issue of exploding/vanishing gradients 
        """
        # sqrt(6 / (split_size + split_size))
        xavier = math.sqrt(3 / self.split_size) 

        with torch.no_grad():
            self.hidden[0].weight.uniform_(-xavier, +xavier)

        # Same as: (in fact I saw how to manually do it there)
        # torch.nn.init.uniform_(self.hidden[0].weight, -xavier, +xavier)

        # Bias is kept as is (0 bias)

    def forward(self, x: torch.Tensor, verbose=False):
        log = print if verbose else lambda x: ()

        log(f"  Input: {x}")

        a_in, b_in = x.split(split_size=self.split_size, dim=self.split_dim)
        log(f"  Split A: {a_in}")
        log(f"  Split B: {b_in}")

        a_out = self.hidden(a_in)
        log(f"  Ouput A: {a_in}")

        b_out = self.hidden(b_in)
        log(f"  Output B: {b_in}")

        y = torch.concat((a_out, b_out), dim=self.split_dim)
        log(f"  Output: {y}")

        return y

In [3]:
def demonstrate_split_layer():
    print("Demonstration of SplitLayer: ")
    
    x = torch.arange(12).reshape((3, 4)).float()
    
    split_layer = SplitLayer(3, 4)
    split_layer.train(False)
    
    split_layer.forward(x, verbose=True)

demonstrate_split_layer()

Demonstration of SplitLayer: 
  Input: tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])
  Split A: tensor([[0., 1.],
        [4., 5.],
        [8., 9.]])
  Split B: tensor([[ 2.,  3.],
        [ 6.,  7.],
        [10., 11.]])
  Ouput A: tensor([[0., 1.],
        [4., 5.],
        [8., 9.]])
  Output B: tensor([[ 2.,  3.],
        [ 6.,  7.],
        [10., 11.]])
  Output: tensor([[ 0.3052,  1.7497,  3.3548,  1.6828],
        [ 6.4043,  1.6159,  9.4539,  1.5490],
        [12.5035,  1.4821, 15.5530,  1.4152]], grad_fn=<CatBackward0>)


## SplitLayer illustration

![title](SplitLayer.png)

## Parameters size

A linear layer that maps an $n$ dimensional vector to an $m$ dimensional vector is a matrix of size $n \times m$ and a bias vector of size $m$.

Thus
- To map the original input $X$ of size $M$ we require $M^2 + M$ parameters.
- To map a splitted input of $X$ which is of size $\frac{M}{2}$ we require $\frac{M^2}{4} + \frac{M}{2}$ parameters

Because we feed both splits to the same layer, we use the same parameters and do not instantiate them twice.

To conclude, _SplitLayer_ uses (asymptotically) a quarter of the number of the parameters used by _LinearLayer_.

## Gradient Calculation for 2-split

Denote by $W$ and $b$ the parameters of the _Linear_ layer.

Then, by the chain rule

$$
\frac{\partial C}{\partial W} = \frac{\partial C}{\partial Y} \cdot \begin{bmatrix}
    \frac{\partial Y_1}{\partial Z_1} \cdot \frac{\partial Z_1}{\partial W} \\
    \frac{\partial Y_2}{\partial Z_2} \cdot \frac{\partial Z_2}{\partial W} \\
\end{bmatrix}
~~;~~
\frac{\partial C}{\partial b} = \frac{\partial C}{\partial Y} \cdot \begin{bmatrix}
    \frac{\partial Y_1}{\partial Z_1} \cdot \frac{\partial Z_1}{\partial b} \\
    \frac{\partial Y_2}{\partial Z_2} \cdot \frac{\partial Z_2}{\partial b} \\
\end{bmatrix}
$$

By definition of the _ReLU_ layer $ReLU(x) = \max \{x, 0 \}$ we get the derivatives

$$
    \frac{\partial Y_i}{\partial Z_i} = \delta (Z_i)
$$

where $\delta$ is an elementwise function s.t $\delta(X) = [\delta(x_i)]$ and the scalar $\delta(x_i)$ is defined by

$$
\delta(x_i) = \begin{cases}
    1 & x_i > 0 \\
    0 & otherwise
\end{cases}
$$

By definition of the _Linear_ layer $Z_i = X_i W + b$ we get the derivatives

$$
\frac{\partial Z_i}{W} = X_i ~~;~~ \frac{\partial Z_i}{b} = 1
$$

Denoting the scalar $\frac{\partial C}{\partial Y} = c$ and putting it all together we get

$$
\frac{\partial C}{\partial W} = c \cdot \begin{bmatrix}
    \delta(Z_1) \cdot X_1 \\
    \delta(Z_2) \cdot X_2 \\
\end{bmatrix}
~~;~~
\frac{\partial C}{\partial b} = c \cdot \begin{bmatrix}
    \delta(Z_1) \\
    \delta(Z_2) \\
\end{bmatrix}
$$

i.e an optimization step with learning rate $\alpha$ is done by

$$
W \leftarrow W - \alpha c \cdot \begin{bmatrix}
    \delta(Z_1) \cdot X_1 \\
    \delta(Z_2) \cdot X_2 \\
\end{bmatrix}
~~;~~
b \leftarrow b - \alpha c \cdot \begin{bmatrix}
    \delta(Z_1) \\
    \delta(Z_2) \\
\end{bmatrix}
$$

## Extension of the Gradient for 4-split

The benefit of presenting the derivatives in vector form is that it is clearly generalized.

For an input the is split to 4 components $X_1, X_2, X_3$ and $X_4$ we get

$$
\frac{\partial C}{\partial W} = c \cdot \begin{bmatrix}
    \delta(Z_1) \cdot X_1 \\
    \delta(Z_2) \cdot X_2 \\
    \delta(Z_3) \cdot X_3 \\
    \delta(Z_4) \cdot X_4 \\
\end{bmatrix}
~~;~~
\frac{\partial C}{\partial b} = c \cdot \begin{bmatrix}
    \delta(Z_1) \\
    \delta(Z_2) \\
    \delta(Z_3) \\
    \delta(Z_4) \\
\end{bmatrix}
$$