In [9]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
def print_stat(x, dist=False):
    """Print min / max / mean / std of a ``torch.Tensor`` or ``np.ndarray``.

    Args:
        x: the tensor or array to summarise. Any other type is silently
            ignored (nothing is printed).
        dist: when True, also report the element count and a 10-bin
            histogram spanning ``[x.min(), x.max()]``.

    Note:
        The torch branch reports ``Tensor.std()`` (unbiased, N-1 denominator)
        while the NumPy branch reports ``np.std`` (population, N denominator),
        so the two can differ slightly for identical data.
        The torch branch always prints the "total num / distribution" line
        (showing ``None`` when ``dist`` is False); the NumPy branch prints it
        only when ``dist`` is True.
    """
    if isinstance(x, torch.Tensor):
        # Work on a detached view so the histogram/statistics are not tracked
        # by autograd.
        t = x.detach()
        total_number = None
        distribution = None
        if dist:
            total_number = t.numel()
            distribution = torch.histc(
                t, bins=10, min=float(t.min()), max=float(t.max())
            )
        print(
            f"min = {t.min().item():-15f} max = {t.max().item():-15f} mean = {t.mean().item():-15f} std = {t.std().item():-15f}\n total num = {total_number} distribution = {distribution}"
        )
    elif isinstance(x, np.ndarray):
        message = f"min = {np.min(x):-15f} max = {np.max(x):-15f} mean = {np.mean(x):-15f} std = {np.std(x):-15f}"
        if dist:
            # Computed per-type: ndarray has no .numel() / torch.histc, which
            # previously made dist=True crash for NumPy input. np.histogram's
            # default range is (x.min(), x.max()), matching the torch branch.
            counts, _ = np.histogram(x, bins=10)
            message += f"\n total num = {x.size} distribution = {counts}"
        print(message)
        
class MyLayerNorm(nn.Module):
    r"""LayerNorm implementation used in ConvNeXt.

    Supports two data formats: ``channels_last`` (default) or ``channels_first``.
    ``channels_last`` normalises over the last dimension, e.g. inputs of shape
    (batch_size, height, width, channels); ``channels_first`` normalises over
    dimension 1, e.g. inputs of shape (batch_size, channels, height, width).

    Args:
        normalized_shape: number of channels C; ``weight`` and ``bias`` are
            created with this shape.
        eps: added to the variance before taking the square root.
        data_format: ``"channels_last"`` or ``"channels_first"``; anything else
            raises ``NotImplementedError``.
        reshape_last_to_first: stored on the module; not used by this class.
        interpolate: ``channels_first`` only. Linearly resample ``weight`` and
            ``bias`` to the input's channel count (``x.shape[1]``) so the layer
            can be applied to inputs whose channel count differs from
            ``normalized_shape``.
        is_linear: accepted for interface compatibility; unused.
    """

    def __init__(
        self,
        normalized_shape,
        eps=1e-6,  # TODO use small dataset to find if -6 is a good order of magnitude
        data_format="channels_last",
        reshape_last_to_first=False,
        interpolate=False,
        is_linear: bool = False,
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        # Tuple form expected by F.layer_norm (assumes an int was passed).
        self.normalized_shape = (normalized_shape,)
        self.reshape_last_to_first = reshape_last_to_first
        self.interpolate = interpolate

    @staticmethod
    def _resample(param, size):
        """Linearly resample a 1-D parameter vector to ``size`` entries."""
        # F.interpolate expects (N, C, L). Index back with [0, 0] instead of
        # squeeze() so that size == 1 still yields a 1-D tensor rather than a
        # 0-d scalar that cannot be reshaped for broadcasting.
        return F.interpolate(
            param[None, None], size=size, mode="linear", align_corners=True
        )[0, 0]

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )
        if self.data_format != "channels_first":
            # Guard against data_format being reassigned after __init__.
            raise NotImplementedError
        # Normalise over dim 1 with the biased variance, as F.layer_norm does.
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        if self.interpolate:
            weight = self._resample(self.weight, x.shape[1])
            bias = self._resample(self.bias, x.shape[1])
        else:
            weight = self.weight
            bias = self.bias
        # Reshape to (C, 1, ..., 1) so the affine parameters broadcast along
        # dim 1 for any input rank >= 2 (for 2-D input this is just (C,)).
        # Previously only ranks 2/4/5 were handled and a 3-D (B, C, L) input
        # was returned without the affine transform.
        affine_shape = (-1,) + (1,) * (x.dim() - 2)
        return weight.reshape(affine_shape) * x + bias.reshape(affine_shape)

    def extra_repr(self) -> str:
        s = "eps=" + str(self.eps)
        s += ", data_format=" + str(self.data_format)
        return s

In [45]:
# Experiment: feed the same random input at several scales (x0.1 ... x10)
# through conv -> depthwise conv -> GELU -> channels_first LayerNorm and
# print the output statistics for each scale.
random_seed = 1234
# Seed *before* building the model so the Conv2d weight initialisation is
# reproducible too; seeding only after construction fixed the input tensor
# but left the weights dependent on whatever RNG state preceded this cell.
torch.manual_seed(random_seed)
test_conv2d = nn.Sequential(
    nn.Conv2d(50, 50, 3, padding=1, bias=True),
    nn.Conv2d(50, 50, 3, padding=1, groups=50, bias=True),  # depthwise
    nn.GELU(),
    # eps=0: no stabilising term is added to the variance before the sqrt.
    MyLayerNorm(50, data_format="channels_first", eps=0),
)
input_tensor = torch.randn(1, 50, 64, 64)
input0p1 = input_tensor * 0.1
input0p5 = input_tensor * 0.5
input5 = input_tensor * 5
input10 = input_tensor * 10
out = test_conv2d(input_tensor)
out0p1 = test_conv2d(input0p1)
out0p5 = test_conv2d(input0p5)
out5 = test_conv2d(input5)
out10 = test_conv2d(input10)
# Report in ascending input-scale order.
for scaled_out in (out0p1, out0p5, out, out5, out10):
    print_stat(scaled_out)

min =       -1.702983 max =        5.690491 mean =       -0.000000 std =        1.000002
 total num = None distribution = None
min =       -1.122204 max =        5.745362 mean =        0.000000 std =        1.000002
 total num = None distribution = None
min =       -1.094324 max =        5.675197 mean =        0.000000 std =        1.000002
 total num = None distribution = None
min =       -1.097608 max =        5.702815 mean =       -0.000000 std =        1.000002
 total num = None distribution = None
min =       -1.095659 max =        5.714510 mean =       -0.000000 std =        1.000002
 total num = None distribution = None
