In [3]:
import torch
from torch import nn

# Time embedding

### Parameters
- `input_dim` is the number of time components. For example, the date 10.02.1993 has 3 components: day, month and year respectively
- `embed_dim_target` is the hidden dimension of the embedding
- `seq_len` is the length of the input sequence
- `BATCH_SIZE` is the size of the minibatch

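Embedding mapping: input (BATCH_SIZE, seq_len, **input_dim**) -> output (BATCH_SIZE, seq_len, **embed_dim**$\times$**input_dim**), where `embed_dim` = `embed_dim_target` / `input_dim`, so the output width equals `embed_dim_target`.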

In [14]:
input_dim = 3
embed_dim_target = 12
seq_len = 5
BATCH_SIZE = 1
SEED = 0  # fixed seed so every torch.rand draw in this notebook is reproducible

# Seed before the first random draw so that Restart & Run All regenerates the same
# input tensor and the same embedding parameters every time.
torch.manual_seed(SEED)

# Random stand-in for a batch of sequences of time components:
# shape (BATCH_SIZE, seq_len, input_dim).
x = torch.rand(BATCH_SIZE, seq_len, input_dim)
print(f"X shape: {x.shape}")
x

X shape: torch.Size([1, 5, 3])


tensor([[[0.4980, 0.2627, 0.5106],
         [0.1706, 0.3551, 0.4692],
         [0.9027, 0.5891, 0.4011],
         [0.7878, 0.1128, 0.7704],
         [0.6210, 0.8645, 0.9596]]])

# Time2Vec

#### 0. Latent space subdivision
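To obtain the specified `embed_dim_target`, we first split it evenly across the time components (`input_dim`): each component gets `embed_dim = embed_dim_target // input_dim` features, so `embed_dim_target` must be divisible by `input_dim`.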

In [5]:
# Each time component receives an equal share of the target embedding width,
# so the target width must be an exact multiple of the number of components.
assert embed_dim_target % input_dim == 0, (
    f"embed_dim_target ({embed_dim_target}) must be divisible by input_dim ({input_dim})"
)
embed_dim = embed_dim_target // input_dim

In [6]:
# Time2Vec parameters: one weight row and one bias row per time component, each
# row producing embed_dim features. These are plain tensors drawn uniformly from
# [0, 1) (no requires_grad), i.e. fixed values for this walkthrough.
# The weight is drawn before the bias.
embed_weight = torch.rand((input_dim, embed_dim))
embed_bias = torch.rand((input_dim, embed_dim))
# Periodic activation applied later to the non-linear features.
act_function = torch.sin

#### 1. Expand the time components on diagonal matrices

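If we have 3 time components for each point in the sequence, they are placed on the diagonal of a 3x3 square matrix. This way, multiplying by the weight matrix in the next step scales row $i$ of the weights by the $i$-th time component, giving each component its own affine map.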

In [7]:
# Put the input_dim time components of every sequence step on the diagonal of an
# (input_dim x input_dim) matrix, adding one trailing axis.
x_diag = x.diag_embed()
print(f"X shape after diag_embed: {x_diag.shape}")
x_diag

X shape after diag_embed: torch.Size([1, 5, 3, 3])


tensor([[[[5.5029e-01, 0.0000e+00, 0.0000e+00],
          [0.0000e+00, 5.5779e-01, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 7.4214e-04]],

         [[7.7242e-01, 0.0000e+00, 0.0000e+00],
          [0.0000e+00, 1.7727e-01, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 1.5817e-01]],

         [[3.2493e-02, 0.0000e+00, 0.0000e+00],
          [0.0000e+00, 9.2086e-01, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 2.0280e-02]],

         [[8.9422e-01, 0.0000e+00, 0.0000e+00],
          [0.0000e+00, 1.0928e-02, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 5.8273e-02]],

         [[4.9967e-01, 0.0000e+00, 0.0000e+00],
          [0.0000e+00, 8.7002e-02, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 9.2712e-01]]]])

#### 2. Multiply the square diagonal matrices with the weights and add the bias
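This step changes the last dimension from `input_dim` (3) to `embed_dim` (4): each time component is mapped to its own `embed_dim`-dimensional vector, so the tensor goes from (BATCH_SIZE, seq_len, `input_dim`, `input_dim`) to (BATCH_SIZE, seq_len, `input_dim`, `embed_dim`).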

In [8]:
# x_diag.shape = (bs, sequence_length, input_dim, input_dim)
# embed_weight.shape = (input_dim, embed_dim); embed_bias has the same shape and
# broadcasts over the leading (bs, sequence_length) axes.
# Because x_diag is diagonal, row i of the product equals x[..., i] * embed_weight[i, :],
# i.e. every time component gets its own affine map into embed_dim features.
x_affine = torch.matmul(x_diag, embed_weight) + embed_bias
# Check the claim above: the diagonal matmul is a row-wise scaling of the weights.
assert torch.allclose(x_affine, x.unsqueeze(-1) * embed_weight + embed_bias)
print(f"Multiplication with the weights of shape: {embed_weight.shape}")
print(f"X shape after matmul: {x_affine.shape}")
x_affine

Moltiplication with the weights of shape: torch.Size([3, 4])
X shape after matmul: torch.Size([1, 5, 3, 4])


tensor([[[[1.1625, 0.7545, 0.8021, 0.9528],
          [0.7828, 0.2907, 0.7901, 1.1776],
          [0.1552, 0.0167, 0.4054, 0.9135]],

         [[1.2985, 0.8171, 0.9025, 0.9880],
          [0.7305, 0.1645, 0.5024, 0.9820],
          [0.2756, 0.1611, 0.4209, 0.9656]],

         [[0.8454, 0.6087, 0.5682, 0.8708],
          [0.8327, 0.4111, 1.0647, 1.3643],
          [0.1701, 0.0346, 0.4073, 0.9199]],

         [[1.3730, 0.8514, 0.9575, 1.0073],
          [0.7077, 0.1093, 0.3766, 0.8965],
          [0.1992, 0.0695, 0.4110, 0.9325]],

         [[1.1315, 0.7403, 0.7793, 0.9448],
          [0.7181, 0.1345, 0.4341, 0.9356],
          [0.8637, 0.8665, 0.4966, 1.2204]]]])

#### 3. Along the last dimension, split the zero-order (linear) component from the higher-order ones, and apply the periodic activation function to the latter.

In [9]:
# Along the last axis, feature 0 stays linear and features 1..embed_dim-1 are the
# ones that get the periodic activation; the split sizes must add up to embed_dim.
x_affine_0, x_affine_remain = x_affine.split(
    [1, embed_dim - 1],
    dim=-1,
)
x_affine_remain = act_function(x_affine_remain)  # overwrite with the activated features
print(f"Zero component shape: {x_affine_0.shape}")
print(f"Higher order components shape: {x_affine_remain.shape}")

Zero component shape: torch.Size([1, 5, 3, 1])
Higher order components shape: torch.Size([1, 5, 3, 3])


#### 4. Recompose the zero and higher order components

In [10]:
# Put the linear feature back in front of the activated ones along the last axis.
x_join = torch.cat((x_affine_0, x_affine_remain), -1)
print(f"X shape: {x_join.shape}")

X shape: torch.Size([1, 5, 3, 4])


#### 5. Reshape the last two dimensions to flatten each [``input_dim``, ``embed_dim``] matrix into a single vector of length ``input_dim``$\times$``embed_dim``
Note that ``input_dim``$\times$``embed_dim`` = ``embed_dim_target``

In [11]:
# Merge the last two axes (input_dim, embed_dim) into one axis of width
# input_dim * embed_dim. flatten(start_dim=-2) works for any number of leading axes
# and, unlike view(size(0), size(1), -1), does not fail on an empty batch.
x_output = x_join.flatten(start_dim=-2)
assert x_output.shape[-1] == embed_dim_target
print(f"X shape: {x_output.shape}")

X shape: torch.Size([1, 5, 12])


To see what this view operation does, take a look at the example below.

In [12]:
# Small 3x3 example where row i is filled with the value i + 1, so the row-major
# flattening order is easy to read off. torch.tensor with an explicit dtype is the
# recommended constructor; the legacy torch.Tensor(...) depends on the global default dtype.
a = torch.tensor([[1, 1, 1], [2, 2, 2], [3, 3, 3]], dtype=torch.float32)
a

tensor([[1., 1., 1.],
        [2., 2., 2.],
        [3., 3., 3.]])

In [13]:
a.view(1, a.numel())  # one row holding all elements in row-major order

tensor([[1., 1., 1., 2., 2., 2., 3., 3., 3.]])