# Setup
$\def\v{\underline}$
$\def\w{\underline w}$
$\def\x{\underline x}$
$\def\f{\underline f}$
$\def\y{\underline y}$
$\def\b{\underline b}$

# nn.Embedding

In [1]:
import torch
from torch import nn

# (batch_size, doc_len): each entry is a token id. Id 1 is treated as the
# padding token (see padding_idx below).
input = torch.tensor([
    [2, 4, 3, 1],
    [2, 0, 1, 1]
])
# A hand-written lookup table so the result is easy to read: row i is the
# vector [i, i, i, i], i.e. vocab size 5 and emb_dim 4.
# from_pretrained keeps these weights exactly as given (row 1 stays
# [1, 1, 1, 1]) and freezes the table by default (freeze=True).
# padding_idx=1 only affects training: that row receives no gradient.
emb_fn = nn.Embedding.from_pretrained(
    torch.tensor([[0, 0, 0, 0],
                  [1, 1, 1, 1],
                  [2, 2, 2, 2],
                  [3, 3, 3, 3],
                  [4, 4, 4, 4]]).float(),
    padding_idx=1,
)
# (batch_size, doc_len, emb_dim): every id is replaced by its table row.
emb = emb_fn(input)
print(emb)

tensor([[[2., 2., 2., 2.],
         [4., 4., 4., 4.],
         [3., 3., 3., 3.],
         [1., 1., 1., 1.]],

        [[2., 2., 2., 2.],
         [0., 0., 0., 0.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]]])


# Linear

Say we have a linear layer with 2 inputs and 3 outputs.

The first output (ignoring the bias for now) is computed as $h_0 = x_0 w_{00} + x_1 w_{01} = \underline x \underline w_0^T$.

Let us stack all weights of the layer in a matrix
$$
W = \begin{bmatrix} 
  - \underline w_0 - \\ 
  - \underline w_1 - \\ 
  - \underline w_2 - 
\end{bmatrix}
$$
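of dimension (output dim, input dim). Each row $i$ holds the weights of neuron $i$, which outputs $h_i$. Then $\underline h = \underline x W^T$. Here, following the PyTorch convention, both $\underline h$ and $\underline x$ are row vectors. For batched inputs, each row of $X$ is one input and the corresponding row of $H$ is its output, so $H = X W^T$. `nn.Linear` additionally adds a bias vector $\underline b$ (broadcast over the rows), i.e. $H = X W^T + \underline b$; the example below uses $\underline b = [1, 1, 1]$.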

$$
\begin{bmatrix} 
  x^{(0)}_0 & x^{(0)}_1 \\ 
  x^{(1)}_0 & x^{(1)}_1
\end{bmatrix}
\begin{bmatrix} 
  w_{00} & w_{10} & w_{20} \\
  w_{01} & w_{11} & w_{21}
\end{bmatrix}
=
\begin{bmatrix} 
  h^{(0)}_0 & h^{(0)}_1 & h^{(0)}_2 \\ 
  h^{(1)}_0 & h^{(1)}_1 & h^{(1)}_2
\end{bmatrix}
$$

<img src="img/impl-rnn-linear.drawio.svg" width="550"/>

In [2]:
emb_dim = 2
hid_dim = 3

# Two identical documents, each made of three words embedded as [1, 1],
# [2, 2] and [3, 3]: (batch_size=2, doc_len=3, emb_dim=2).
input = torch.tensor([[[1., 1.], [2., 2.], [3., 3.]]] * 2)

linear_fn = nn.Linear(emb_dim, hid_dim)
# Replace the random initialisation with easy-to-follow values.
# weight is laid out as (out_features, in_features): row i holds the weights
# of output neuron i, so neuron 0 ignores its input, neuron 1 copies x_1 and
# neuron 2 copies x_0.
linear_fn.weight = nn.Parameter(torch.tensor([[0., 0.],
                                              [0., 1.],
                                              [1., 0.]]))
# bias has shape (out_features,): every neuron gets +1.
linear_fn.bias = nn.Parameter(torch.ones(hid_dim))

# (batch_size, doc_len, hid_dim)
linear_fn(input)

tensor([[[1., 2., 2.],
         [1., 3., 3.],
         [1., 4., 4.]],

        [[1., 2., 2.],
         [1., 3., 3.],
         [1., 4., 4.]]], grad_fn=<AddBackward0>)

# RNN

Here is an illustration of the computational graph.

<img src="img/impl-rnn-rnn.drawio.svg" width="750"/>

And here is a higher level illustration (type 2 on the right below).

<img src="img/impl-rnn-rnn-2.jpg" width="650"/>


In [3]:
# (batch size, doc len, emb dim) = (3, 2, 2); batch-first because the RNN
# below is built with batch_first=True.
input = torch.tensor(
    [
        [[11, 11],
         [12, 12]],

        [[21, 21],
         [22, 22]],

        [[31, 31],
         [32, 32]]
     ],
).float()
print(input)

rnn_fn = nn.RNN(input_size=2, hidden_size=5, num_layers=1,
                batch_first=True)

out_across_time, hid_last_time = rnn_fn(input)

# out_across_time: top-layer hidden state at every time step,
#   (batch size, doc len, hidden size * num directions).
# hid_last_time: hidden state after the last time step, per layer/direction,
#   (num layers * num directions, batch size, hidden size).
#   batch_first only changes the layout of input/output, NOT of the hidden state.

print(out_across_time.size())
print(out_across_time)

print(hid_last_time.size())
print(hid_last_time)

# With a single layer and a single direction, the final hidden state equals
# the output at the last time step.
assert out_across_time.shape == (3, 2, 5)
assert hid_last_time.shape == (1, 3, 5)
assert torch.allclose(hid_last_time[0], out_across_time[:, -1])

tensor([[[11., 11.],
         [12., 12.]],

        [[21., 21.],
         [22., 22.]],

        [[31., 31.],
         [32., 32.]]])
torch.Size([3, 2, 5])
tensor([[[-0.9819, -0.5371,  1.0000,  0.9950,  1.0000],
         [-0.9871, -0.7188,  1.0000,  0.9984,  1.0000]],

        [[-0.9998, -0.9091,  1.0000,  1.0000,  1.0000],
         [-0.9998, -0.9500,  1.0000,  1.0000,  1.0000]],

        [[-1.0000, -0.9851,  1.0000,  1.0000,  1.0000],
         [-1.0000, -0.9919,  1.0000,  1.0000,  1.0000]]],
       grad_fn=<TransposeBackward1>)
torch.Size([1, 3, 5])
tensor([[[-0.9871, -0.7188,  1.0000,  0.9984,  1.0000],
         [-0.9998, -0.9500,  1.0000,  1.0000,  1.0000],
         [-1.0000, -0.9919,  1.0000,  1.0000,  1.0000]]],
       grad_fn=<StackBackward0>)


In [5]:
# Batch-first layout: (batch size, doc len, emb dim) = (4, 2, 2).
# Documents 3 and 4 contain a single word; their second position is padding.
input_batch_first = torch.tensor(
    [
        [[11, 11],
         [12, 12]],

        [[21, 21],
         [22, 22]],

        [[31, 31],
         [1, 1]], # padding

        [[41, 41],
         [1, 1]] # padding
     ]
).float()

# pack_padded_sequence defaults to batch_first=False, so it expects the
# time-major layout (doc len, batch size, emb dim). Deriving it with transpose
# keeps both layouts consistent by construction. It evaluates to:
#   time step 1 (first word of every document): [11, 11], [21, 21], [31, 31], [41, 41]
#   time step 2: [12, 12], [22, 22], [1, 1] (padding), [1, 1] (padding)
input = input_batch_first.transpose(0, 1).contiguous()
# Number of real (non-padding) words per document: (batch size,)
dl = torch.tensor([2, 2, 1, 1])

# With the default enforce_sorted=True, documents must be sorted by
# decreasing length.
packed = nn.utils.rnn.pack_padded_sequence(input, dl)

# packed.data = torch.tensor(
#     [
#         # time step 1 across all documents
#         [11., 11.],
#         [21., 21.],
#         [31., 31.],
#         [41., 41.],
#         # Time step 2 across all documents that have such a time step
#         # Recall input is sorted by length
#         [12., 12.],
#         [22., 22.]
#     ]
# ), of dim (sum of all doc lengths, emb dim).
#
# packed.batch_sizes = [4, 2].
# Position t stores the number of documents in the batch that still have words
# at time t.
print(packed)

rnn_fn = nn.RNN(input_size=2, hidden_size=3, num_layers=2, bidirectional=True)
# out_across_time.data: (sum of lens of all batch docs, hidden size * num directions)
# hid_last_time: (num layers * num directions, batch size, hidden size)
out_across_time, hid_last_time = rnn_fn(packed)

# Unpack back into a padded, time-major tensor. Padded positions are filled
# with padding_value=1.0 (visible as rows of 1s in the printout).
# (doc len, batch size, hidden size * num directions)
out_across_time, _ = nn.utils.rnn.pad_packed_sequence(
    out_across_time,
    padding_value=1.0
)

print(out_across_time.transpose(0, 1))

# Sanity checks on the shapes documented above.
assert packed.batch_sizes.tolist() == [4, 2]
assert hid_last_time.shape == (2 * 2, 4, 3)
assert out_across_time.shape == (2, 4, 2 * 3)
assert torch.all(out_across_time[1, 2:] == 1.0)

PackedSequence(data=tensor([[11., 11.],
        [21., 21.],
        [31., 31.],
        [41., 41.],
        [12., 12.],
        [22., 22.]]), batch_sizes=tensor([4, 2]), sorted_indices=None, unsorted_indices=None)
tensor([[[-0.1654, -0.0464, -0.0121, -0.6202, -0.2044,  0.3503],
         [-0.1004, -0.0021,  0.0830, -0.3628, -0.0523,  0.4666]],

        [[-0.2162, -0.0467, -0.0456, -0.6263, -0.2797,  0.3420],
         [-0.1428,  0.0275,  0.0870, -0.3663, -0.1060,  0.4790]],

        [[-0.1998, -0.0466, -0.0348, -0.3691, -0.1528,  0.4897],
         [ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000]],

        [[-0.2321, -0.0467, -0.0563, -0.3720, -0.2018,  0.5008],
         [ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000]]],
       grad_fn=<TransposeBackward0>)


# Sequence to sequence
<img src="img/impl-rnn-s2s.png" width="550" />