# 绝对位置编码——sinusoidal

> 出自论文 `Attention Is All You Need`

In [None]:
import math
import torch
import torch.nn as nn
import numpy as np

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch-first input.

    The encoding table is computed once at construction time and registered
    as a non-persistent buffer, so it is not written to ``state_dict``.
    ``forward`` adds the first ``seq_len`` rows of that table to the input
    and then applies dropout.

    Args:
        d_model: Size of the feature (last) dimension of the input.
        dropout_prob: Dropout probability applied after adding the encodings.
        max_len: Number of positions to precompute, i.e. the longest
            sequence length ``forward`` accepts.
    """

    def __init__(self, d_model: int, dropout_prob: float, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(dropout_prob)
        encodings = self.get_positional_encoding(d_model, max_len)
        # The table is deterministic, so there is no need to checkpoint it.
        self.register_buffer('positional_encodings', encodings, persistent=False)

    @staticmethod
    def get_positional_encoding(d_model: int, max_len: int) -> torch.Tensor:
        """Build the sinusoidal encoding table.

        For position ``pos`` and even feature index ``2i`` the table holds
        ``sin(pos / 10000^(2i / d_model))``; the following odd index holds
        the cosine of the same angle.

        Args:
            d_model: Number of feature dimensions (may be odd).
            max_len: Number of positions to generate.

        Returns:
            A tensor of shape ``(1, max_len, d_model)`` with
            ``requires_grad`` set to ``False``.
        """
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        two_i = torch.arange(0, d_model, 2, dtype=torch.float32)
        # exp(-2i * ln(10000) / d_model) == 1 / 10000^(2i / d_model)
        div_term = torch.exp(two_i * -(math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        encodings = torch.zeros(max_len, d_model)
        encodings[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even
        # columns, so drop the surplus angle column before assigning.
        encodings[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        return encodings.unsqueeze(0).requires_grad_(False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encodings to ``x`` and apply dropout.

        Args:
            x: Input whose axis 1 is the sequence axis; expected layout is
                ``(batch, seq_len, d_model)`` so that it broadcasts against
                the ``(1, seq_len, d_model)`` slice of the table.

        Returns:
            ``dropout(x + pe)`` with the same shape as ``x``.

        Raises:
            ValueError: If ``seq_len`` exceeds the precomputed ``max_len``.
        """
        seq_len = x.shape[1]
        max_len = self.positional_encodings.shape[1]
        if seq_len > max_len:
            raise ValueError(
                "sequence length %d exceeds max_len %d" % (seq_len, max_len)
            )
        # The table is (1, max_len, d_model): slice the position axis
        # (axis 1), not the leading singleton axis.
        pe = self.positional_encodings[:, :seq_len]
        return self.dropout(x + pe)

def _test_positional_encoding():
    """Plot feature dimensions 4-7 of the encoding table over 100 positions.

    Builds a table for ``d_model=20`` and ``max_len=100``, prints its shape,
    and shows one curve per selected dimension with matplotlib.
    """
    import matplotlib.pyplot as plt

    max_len, d_model = 100, 20
    plt.figure(figsize=(15, 5))
    pe = PositionalEncoding.get_positional_encoding(d_model, max_len)
    print(pe.shape)
    # The table is (1, max_len, d_model): index away the leading singleton
    # axis and keep every position so y has one row per x value.
    plt.plot(np.arange(max_len), pe[0, :, 4:8].numpy())
    plt.legend(["dim %d" % p for p in [4, 5, 6, 7]])
    plt.title("Positional encoding")
    plt.show()

# Run the plotting demo only when this file is executed as a script.
if __name__ == "__main__":
    _test_positional_encoding()

# 相对位置编码
