In [3]:
%%html
<!-- Link to the official PyTorch reference page for torch.nn.LayerNorm. -->
<a href="https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html">LayerNorm</a>

In [5]:
from typing import Tuple, Union

import torch
import torch.nn as nn
from torch.nn.functional import normalize


def my_layer_norm(
    x: torch.Tensor, dim: Union[int, Tuple[int, ...]], eps: float = 0.00001
) -> torch.Tensor:
    """Normalize ``x`` to zero mean and unit variance over ``dim``.

    Hand-written layer normalization without any learnable scale or shift:
    ``(x - mean) / sqrt(var + eps)``. The variance is the biased (population)
    variance, i.e. the plain mean of the squared deviations.

    Args:
        x: Input tensor.
        dim: Dimension, or tuple of dimensions, over which the mean and
            variance are computed. The reduced dimensions are kept with
            size 1 so the statistics broadcast back against ``x``.
        eps: Small constant added to the variance before the square root so
            a zero-variance slice does not divide by zero.

    Returns:
        The normalized tensor, with the same shape as ``x``.
    """
    mean = torch.mean(x, dim=dim, keepdim=True)
    # Center once and reuse it for both the variance and the final result.
    centered = x - mean
    var = torch.square(centered).mean(dim=dim, keepdim=True)
    return centered / torch.sqrt(var + eps)


# Compare nn.LayerNorm with the hand-written my_layer_norm on a tiny input.
batch, sentence_length, embedding_dim = 1, 3, 3
embedding = torch.tensor(
    [[[1.0, 2.0, 3.0],
     [2.0, 3.0, 4.0],
     [1.0, 4.0, 6.0]]])
# Keep the declared dimensions and the literal data in sync.
assert embedding.shape == (batch, sentence_length, embedding_dim)
print(embedding.shape)

# Normalize over the last (embedding) dimension. The affine parameters are
# untouched after construction (weight 1 / bias 0 per the PyTorch docs), so
# the result should equal the plain normalization.
layer_norm = nn.LayerNorm(embedding_dim, elementwise_affine=True)
output = layer_norm(embedding)
print(output)

# Normalize each row of the single batch entry on its own; the printed rows
# should line up with the rows of `output` above.
for token in embedding[0]:
    print(my_layer_norm(token, 0))

# Make the comparison explicit instead of relying on eyeballing the prints.
assert torch.allclose(output, my_layer_norm(embedding, -1), atol=1e-6)

torch.Size([1, 3, 3])
tensor([[[-1.2247e+00,  0.0000e+00,  1.2247e+00],
         [-1.2247e+00,  1.1921e-07,  1.2247e+00],
         [-1.2978e+00,  1.6222e-01,  1.1355e+00]]],
       grad_fn=<NativeLayerNormBackward0>)
tensor([-1.2247,  0.0000,  1.2247])
tensor([-1.2247,  0.0000,  1.2247])
tensor([-1.2978,  0.1622,  1.1355])


In [7]:
import torch
# Show how torch.mean behaves with and without `dim` / `keepdim`.
# torch.tensor is the recommended factory function; it infers the dtype from
# the data instead of going through the legacy torch.Tensor constructor.
a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
print(a.shape)
print(torch.mean(a))                         # no dim: mean over all elements
print(torch.mean(a, dim=0, keepdim=True))    # reduce dim 0, keep it as size 1
print(torch.mean(a, dim=1, keepdim=True))    # reduce dim 1, keep it as size 1
print(torch.mean(a, dim=0, keepdim=False))   # reduce dim 0 and drop it
print(torch.mean(a, dim=1, keepdim=False))   # reduce dim 1 and drop it

torch.Size([2, 2])
tensor(2.5000)
tensor([[2., 3.]])
tensor([[1.5000],
        [3.5000]])
tensor([2., 3.])
tensor([1.5000, 3.5000])


In [2]:
# Print the built-in documentation of torch.mean to check its `dim` and
# `keepdim` arguments.
help(torch.mean)

Help on built-in function mean in module torch:

mean(...)
    mean(input, *, dtype=None) -> Tensor
    
    Returns the mean value of all elements in the :attr:`input` tensor.
    
    Args:
        input (Tensor): the input tensor.
    
    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
            If specified, the input tensor is casted to :attr:`dtype` before the operation
            is performed. This is useful for preventing data type overflows. Default: None.
    
    Example::
    
        >>> a = torch.randn(1, 3)
        >>> a
        tensor([[ 0.2294, -0.5481,  1.3288]])
        >>> torch.mean(a)
        tensor(0.3367)
    
    .. function:: mean(input, dim, keepdim=False, *, dtype=None, out=None) -> Tensor
       :noindex:
    
    Returns the mean value of each row of the :attr:`input` tensor in the given
    dimension :attr:`dim`. If :attr:`dim` is a list of dimensions,
    reduce over all of them.
    
    
    I

In [2]:
import torch

# Demonstrate Tensor.masked_fill with a mask that has fewer leading entries
# than the tensor it is applied to.
# A dedicated, seeded generator makes the random input reproducible without
# touching the global RNG state used by other cells.
rng = torch.Generator().manual_seed(0)
ts = torch.randn((2, 3, 2, 2), dtype=torch.float, generator=rng)
mask = torch.tensor([[
    [[1, 0],
    [0, 1]],
]])
print(mask.shape)
# True marks the positions that will be overwritten (where the mask is 0).
blocked = mask == 0
print(blocked)
print("before mask:")
print(ts)
# The (1, 1, 2, 2) boolean mask is applied to the (2, 3, 2, 2) tensor, so the
# same 2x2 pattern is filled in every leading slice.
ts = ts.masked_fill(blocked, float('-inf'))
print("after mask:")
print(ts)

# Reference line this cell mirrors (presumably from an attention
# implementation - confirm the source):
# att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))

torch.Size([1, 1, 2, 2])
tensor([[[[False,  True],
          [ True, False]]]])
before mask:
tensor([[[[-0.4132, -0.2266],
          [ 0.2604,  1.0856]],

         [[-2.3782, -0.5554],
          [ 0.6914, -0.2681]],

         [[ 0.9550, -0.7643],
          [-1.2810,  0.0563]]],


        [[[-0.0229,  0.9606],
          [-0.6099,  0.1671]],

         [[-0.3016, -1.1551],
          [ 1.0394,  0.8793]],

         [[ 0.2370, -0.2745],
          [-0.8201, -2.4180]]]])
after maks:
tensor([[[[-0.4132,    -inf],
          [   -inf,  1.0856]],

         [[-2.3782,    -inf],
          [   -inf, -0.2681]],

         [[ 0.9550,    -inf],
          [   -inf,  0.0563]]],


        [[[-0.0229,    -inf],
          [   -inf,  0.1671]],

         [[-0.3016,    -inf],
          [   -inf,  0.8793]],

         [[ 0.2370,    -inf],
          [   -inf, -2.4180]]]])
