In [1]:
# -*- coding: utf-8 -*-

import pandas as pd
import altair as alt
import torch
import torch.nn as nn
import warnings
import math

# Set to False to skip notebook execution (e.g. for debugging)
RUN_EXAMPLES = True
# Ignore every warning category for the remainder of the session.
warnings.filterwarnings("ignore")

def subsequent_mask(size):
    """Build a mask that hides every position after the current one.

    Returns a boolean tensor of shape (1, size, size) whose entry
    [0, i, j] is True when j <= i and False otherwise.
    """
    positions = torch.arange(size)
    # Row i may look at column j only when j is not ahead of i.
    visible = positions.unsqueeze(0) <= positions.unsqueeze(1)
    return visible.unsqueeze(0)

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence, then apply dropout.

    Even feature columns hold sines and odd feature columns hold cosines of
    ``position * 10000 ** (-2i / d_model)``. The table is computed once in
    ``__init__`` and stored as the non-trainable buffer ``pe`` with shape
    (1, max_len, d_model).

    Args:
        d_model: Size of the feature (last) dimension; may be even or odd.
        dropout: Dropout probability applied to the summed output.
        max_len: Number of positions for which encodings are precomputed.
    """

    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so trim the cosine frequencies instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:, :x.size(1)])``.

        Args:
            x: Tensor whose dimension 1 indexes sequence positions and whose
                last dimension is ``d_model``, e.g. (batch, seq_len, d_model).
        """
        # Only the first x.size(1) positions of the precomputed table are used.
        x = x + self.pe[:, : x.size(1)].requires_grad_(False)
        return self.dropout(x)


def is_interactive_notebook():
    """Report whether this module is executing as the top-level ``__main__``."""
    running_as_main = "__main__" == __name__
    return running_as_main


def show_example(fn, args=()):
    """Call ``fn(*args)`` and return its result when examples are enabled.

    The call happens only when the module runs as ``__main__`` and
    ``RUN_EXAMPLES`` is truthy; otherwise nothing is called and ``None`` is
    returned.

    Args:
        fn: Callable to invoke.
        args: Iterable of positional arguments unpacked into ``fn``. The
            default is an immutable empty tuple rather than a shared list.
    """
    if is_interactive_notebook() and RUN_EXAMPLES:
        return fn(*args)
    return None


def execute_example(fn, args=()):
    """Call ``fn(*args)`` when examples are enabled, discarding its result.

    The call happens only when the module runs as ``__main__`` and
    ``RUN_EXAMPLES`` is truthy. This function always returns ``None``.

    Args:
        fn: Callable to invoke.
        args: Iterable of positional arguments unpacked into ``fn``. The
            default is an immutable empty tuple rather than a shared list.
    """
    if is_interactive_notebook() and RUN_EXAMPLES:
        fn(*args)



def example_mask():
    """Chart the subsequent-position mask for a sequence of length 20.

    Returns:
        An interactive Altair heatmap with "Window" on the x axis, "Masking"
        on the y axis, and the mask value at [Masking, Window] as the colour.
    """
    size = 20
    # The mask does not depend on the cell being plotted, so build it once.
    mask = subsequent_mask(size)[0]

    # One record per (Masking, Window) cell, in the same order as before:
    # Window (y) varies slowest, Masking (x) fastest.
    LS_data = pd.DataFrame(
        [
            {
                "Subsequent Mask": mask[x, y].item(),
                "Window": y,
                "Masking": x,
            }
            for y in range(size)
            for x in range(size)
        ]
    )

    return (
        alt.Chart(LS_data)
        .mark_rect()
        .properties(height=250, width=250)
        .encode(
            alt.X("Window:O"),
            alt.Y("Masking:O"),
            alt.Color("Subsequent Mask:Q", scale=alt.Scale(scheme="viridis")),
        )
        .interactive()
    )



def example_positional():
    """Chart positional-encoding values along the sequence for dimensions 4-7.

    Feeds an all-zero (1, 100, 20) input through ``PositionalEncoding`` with
    dropout 0, so the output equals the encoding itself.

    Returns:
        An interactive Altair line chart of "embedding" against "position",
        one line per plotted "dimension".
    """
    d_model, seq_len = 20, 100
    pe = PositionalEncoding(d_model, 0)
    # Invoke the module itself rather than .forward() so that nn.Module's
    # call machinery (e.g. registered hooks) is not bypassed.
    y = pe(torch.zeros(1, seq_len, d_model))

    data = pd.concat(
        [
            pd.DataFrame(
                {
                    "embedding": y[0, :, dim],
                    "dimension": dim,
                    "position": list(range(seq_len)),
                }
            )
            for dim in [4, 5, 6, 7]
        ]
    )

    return (
        alt.Chart(data)
        .mark_line()
        .properties(width=800)
        .encode(x="position", y="embedding", color="dimension:N")
        .interactive()
    )


In [2]:
# Route through show_example so that RUN_EXAMPLES = False skips this cell.
show_example(example_mask)

In [4]:
# Route through show_example so that RUN_EXAMPLES = False skips this cell.
show_example(example_positional)

In [5]:
# Demo: apply LayerNorm over the last dimension of a random
# (batch, sentence_length, embedding_dim) tensor.
batch = 20
sentence_length = 5
embedding_dim = 10
embedding = torch.randn((batch, sentence_length, embedding_dim))
layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)
layer_norm(embedding)

tensor([[[-0.2552,  0.0339, -1.4545, -0.7597,  2.2567, -0.4952,  1.2078,
          -0.2459, -0.5578,  0.2699],
         [-0.3009, -2.1998,  1.6541, -0.3955,  0.8265,  0.2539, -0.0366,
           0.3005, -0.8676,  0.7654],
         [ 2.2809,  1.0522,  0.0943, -0.4245, -0.3283, -1.3819, -0.2123,
          -0.8842, -0.6627,  0.4664],
         [-1.6838, -1.4295,  0.4746, -0.0926,  0.7973, -1.0922,  0.3282,
           0.3866,  1.4113,  0.9001],
         [ 1.9142,  0.9444,  0.9031, -1.3955, -1.1987, -0.7263, -0.3427,
          -0.1494,  0.5617, -0.5109]],

        [[-2.3236, -0.2702, -0.1060, -0.6588,  1.0834,  1.6733,  0.1258,
           0.0687,  0.1505,  0.2569],
         [-1.2379,  1.6578,  0.9486, -1.3408, -0.9931,  0.2944,  0.4799,
          -0.0902,  1.0546, -0.7732],
         [-0.7235,  0.3261,  1.3529, -1.0218, -1.9142, -0.1949,  0.4771,
           1.5967, -0.0239,  0.1256],
         [-0.0650, -0.0946,  0.5021, -1.8147,  1.3357, -0.6814,  1.1816,
          -0.1823,  1.0813, -1.2627],