In [8]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

def positional_signal(hidden_size,
                      length,
                      min_timescale=1.0,
                      max_timescale=1e4):
    """Build a sinusoidal positional encoding (Tensor2Tensor style).

    Partially based on the Tensor2Tensor implementation:
    https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/layers/common_attention.py

    Channel pair ``i`` (for ``i`` in ``range(hidden_size // 2)``) uses the
    frequency ``min_timescale * exp(-i * log(max_timescale / min_timescale)
    / max(hidden_size // 2 - 1, 1))``. For every position ``p`` the output
    row holds ``sin(p * freq_i)`` for all ``i``, followed by ``cos(p * freq_i)``
    for all ``i``. The two halves are concatenated, not interleaved.

    Args:
        hidden_size: Width of the encoding; must be even and non-negative.
        length: Number of positions (0 .. length-1) to encode; must be
            non-negative.
        min_timescale: Scale factor of the first (highest) frequency.
        max_timescale: Together with ``min_timescale``, sets the total
            geometric decay of the frequencies.

    Returns:
        A ``torch.float32`` tensor of shape ``(1, length, hidden_size)``.

    Raises:
        ValueError: If ``hidden_size`` is odd, or if ``hidden_size`` or
            ``length`` is negative.
    """
    if hidden_size % 2 != 0:
        raise ValueError(
            f"The hidden dimension of the model must be divisible by 2. Currently it is {hidden_size}"
        )
    if hidden_size < 0 or length < 0:
        raise ValueError(
            f"hidden_size and length must be non-negative, got hidden_size={hidden_size}, length={length}"
        )

    num_timescales = hidden_size // 2
    # Guard the denominator as Tensor2Tensor does. With hidden_size == 2 there
    # is a single timescale, and dividing by zero would give an infinite
    # increment that turns the encoding into NaN (0 * -inf).
    log_timescale_increment = (
        math.log(float(max_timescale) / float(min_timescale))
        / max(num_timescales - 1, 1))

    # Compute the frequencies in float64 and cast afterwards to limit rounding error.
    exponents = torch.arange(num_timescales, dtype=torch.float64) * -log_timescale_increment
    inv_timescales = (min_timescale * torch.exp(exponents)).to(torch.float32)

    position = torch.arange(length, dtype=torch.float32)
    # Outer product: scaled_time[p, i] = position[p] * inv_timescales[i].
    scaled_time = position.unsqueeze(1) * inv_timescales.unsqueeze(0)
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
    # Leading singleton axis so the result can be broadcast-added to a batch.
    return signal.unsqueeze(0)

# Encode 128 positions with a 768-wide signal; the result has shape (1, 128, 768).
signal = positional_signal(length=128, hidden_size=768)
signal

tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  1.0000,  1.0000,  1.0000],
         [ 0.8415,  0.8284,  0.8152,  ...,  1.0000,  1.0000,  1.0000],
         [ 0.9093,  0.9280,  0.9443,  ...,  1.0000,  1.0000,  1.0000],
         ...,
         [-0.6160,  0.4726, -0.2476,  ...,  0.9999,  0.9999,  0.9999],
         [ 0.3300, -0.4653,  0.6464,  ...,  0.9999,  0.9999,  0.9999],
         [ 0.9726, -0.9939,  0.9964,  ...,  0.9999,  0.9999,  0.9999]]])

In [13]:
# Along dim 0, gather computes out[i][j] = t[index[i][j]][j], so an all-zero
# index repeats row 0 of `t` in every output row.
t = torch.tensor([[1, 2], [3, 4]])
t.gather(0, torch.zeros_like(t))

tensor([[1, 2],
        [1, 2]])