In [3]:
import mlx.core as mx
import mlx.nn as nn
from typing import List

In [19]:
class Transpose(nn.Module):
    """This module returns a view of the array input with its dimensions permuted.

    Args:
        dims (List[int]): The desired ordering of dimensions
    """

    def __init__(self, dims: List[int]):
        super().__init__()
        self.dims = dims

    def __call__(self, x: mx.array) -> mx.array:
        return mx.transpose(x, self.dims)

In [20]:
class LayerNorm2d(nn.LayerNorm):
    def forward(self, x: mx.array) -> mx.array:
        x = x.transpose(0, 2, 3, 1)
        x = mx.fast.layer_norm(x, self.weight, self.bias, self.eps)
        x = x.transpose(0, 3, 1, 2)
        return x

In [38]:
block = nn.Sequential(
            Transpose([0, 2, 3, 1]),
            LayerNorm2d(3),
            Transpose([0, 3, 1, 2]),
        )

In [39]:
X = mx.random.normal((100, 3,3,3))
X.shape

(100, 3, 3, 3)

In [40]:
block(X)

array([[[[-0.591596, 0.829973, -0.536998],
         [0.840898, 0.261511, 0.185647],
         [-1.35373, -0.0762603, -0.225166]],
        [[1.40782, 0.576644, 1.40147],
         [-1.40516, -1.33436, 1.12132],
         [0.322642, -1.18481, 1.3217]],
        [[-0.816223, -1.40662, -0.864469],
         [0.56426, 1.07285, -1.30697],
         [1.03109, 1.26107, -1.09653]]],
       [[[-0.970692, -0.134339, -0.48444],
         [1.0329, 0.252414, 1.12976],
         [-1.25011, 1.0297, 0.585882]],
        [[-0.405329, -1.15203, 1.39281],
         [-1.35282, 1.07887, -1.30154],
         [0.0597848, 0.324656, 0.821752]],
        [[1.37602, 1.28637, -0.908373],
         [0.31992, -1.33128, 0.171784],
         [1.19033, -1.35435, -1.40763]]],
       [[[1.36665, 0.0854834, 0.879047],
         [1.16233, -1.3574, -1.03445],
         [0.664237, 1.23067, -0.537278]],
        [[-0.368432, 1.17976, -1.39892],
         [-1.27882, 1.0223, -0.317717],
         [-1.41336, -1.21872, 1.40151]],
        [[-0.99821

In [41]:
X

array([[[[-1.05949, 0.452603, 0.227415],
         [0.596406, -0.645721, 0.309252],
         [0.21465, 0.137628, -0.421939]],
        [[-0.804221, 0.265521, 0.945528],
         [-1.65098, -1.49349, 1.7714],
         [0.952262, -0.38495, 1.13107]],
        [[-1.08816, -1.1991, 0.106102],
         [0.319604, -0.214718, -2.02321],
         [1.26398, 0.768056, -1.29677]]],
       [[[-1.69183, -0.80097, -0.947339],
         [0.94515, 0.883198, 0.770867],
         [-0.306263, 0.552916, 0.0780474]],
        [[-1.13534, -1.62361, -0.29756],
         [0.42094, 1.77068, -0.394461],
         [-0.250868, -0.0224432, 0.309122]],
        [[0.618029, 0.347448, -1.09408],
         [0.788487, -0.817448, 0.311708],
         [-0.203058, -1.39262, -1.87494]]],
       [[[1.71975, -0.236262, 0.159263],
         [1.47365, -1.02568, 1.1077],
         [-0.0954925, 1.18954, -0.0364512]],
        [[0.653519, 0.73636, -1.6615],
         [-1.95671, 0.532494, 1.26596],
         [-2.02499, -1.19534, 0.690262]],
     