# Define RNN Cell

RNN: Giả sử sample X = [3, 3] => Mỗi x_t = [1, 3]
1. Tính output của lớp hidden layer (có bias), với d_h = 2.
2. Tính output của model ở lớp cuối cùng, với d_out = 1. (regression).
3. Chuyển bài toán thành bài toán Sequence Labeling, tính output của từng time-steps.
<!-- 4. Bidirectional RNN, đưa ra output cuối cùng. -->

LSTM: Không làm bias, quá phức tạp.
1. Tính forget gate.
2. Tính add gate.
3. Tính output gate.
4. Tính final output.
<!-- 4. Stacked LSTM với hàm kích hoạt RELU. -->

In [1]:
import torch
import numpy as np
import torch.nn as nn


class RNNCell(nn.Module):
    def __init__(self, input_size, hidden_size, bias=True, nonlinearity="tanh"):
        super(RNNCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        if nonlinearity not in ["tanh", "relu"]:
            raise ValueError("Invalid nonlinearity. Choose 'tanh' or 'relu'.")

        self.nonlinearity = nonlinearity

        # Linear transformations
        self.input_to_hidden = nn.Linear(input_size, hidden_size, bias=bias)
        self.hidden_to_hidden = nn.Linear(hidden_size, hidden_size, bias=bias)

    def set_weights(
        self,
        input_to_hidden_weight,
        hidden_to_hidden_weight,
        input_to_hidden_bias,
        hidden_to_hidden_bias,
    ):
        self.input_to_hidden.weight.data = input_to_hidden_weight
        self.hidden_to_hidden.weight.data = hidden_to_hidden_weight
        self.input_to_hidden.bias.data = input_to_hidden_bias
        self.hidden_to_hidden.bias.data = hidden_to_hidden_bias

    def forward(self, input, hidden_state_input=None):
        if hidden_state_input is None:
            hidden_state_input = input.new_zeros(
                input.size(0), self.hidden_size, requires_grad=False
            )

        # Compute the new hidden state
        hidden_state = self.input_to_hidden(input) + self.hidden_to_hidden(
            hidden_state_input
        )
        hidden_state = (
            torch.tanh(hidden_state)
            if self.nonlinearity == "tanh"
            else torch.relu(hidden_state)
        )

        return hidden_state

In [None]:
#  1. Tính output của lớp hidden layer (có bias), với d_h = 2.
X = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])

torch.manual_seed(42)
rnn = RNNCell(input_size=3, hidden_size=2)

input_to_hidden_weight = torch.tensor([[-1.0, 0.0, 0.5], [1.0, 0.0, -0.5]])
input_to_hidden_bias = torch.tensor([1.0])
hidden_to_hidden_weight = torch.tensor([[0.0, -1.0]])
hidden_to_hidden_bias = torch.tensor([1.0])

rnn.set_weights(
    input_to_hidden_weight,
    hidden_to_hidden_weight,
    input_to_hidden_bias,
    hidden_to_hidden_bias,
)

with torch.no_grad():
    # Output of the hidden state
    hidden_state = None
    for i in range(3):
        hidden_state = rnn(X[i].unsqueeze(0), hidden_state)
        print(
            "Output of the hidden state at time step {}: \n{}".format(i, hidden_state)
        )

Output of the hidden state at time step 0: 
tensor([[0.9866, 0.9051]])
Output of the hidden state at time step 1: 
tensor([[0.0946, 0.9702]])
Output of the hidden state at time step 2: 
tensor([[-0.8996,  0.9983]])


In [3]:
Wxh = torch.tensor([[-1.0, 0.0, 0.5], [1.0, 0.0, -0.5]])
bxh = torch.tensor([1.0])
Whh = torch.tensor([[0.0, -1.0], [0.0, -1.0]])
bhh = torch.tensor([1.0])

x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
h = torch.tensor([[0.0, 0.0]])
for i in range(3):
    # Compute the new hidden state using the RNN cell formula
    h = torch.tanh(x[i] @ Wxh.T + h @ Whh.T + bxh + bhh)
    print(h)

tensor([[0.9866, 0.9051]])
tensor([[0.0946, 0.9702]])
tensor([[-0.8996,  0.9983]])


In [24]:
rnn = torch.nn.RNN(
    input_size=3,
    hidden_size=2,
    bias=True,
    nonlinearity="tanh",
)

rnn.weight_ih_l0 = torch.nn.Parameter(Wxh)
rnn.weight_hh_l0 = torch.nn.Parameter(Whh)
rnn.bias_ih_l0 = torch.nn.Parameter(bxh)
rnn.bias_hh_l0 = torch.nn.Parameter(bhh)

x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
output, hn = rnn(x)
output, output.shape, hn, hn.shape

(tensor([[ 0.9866,  0.9051],
         [ 0.0946,  0.9702],
         [-0.8996,  0.9983]], grad_fn=<SqueezeBackward1>),
 torch.Size([3, 2]),
 tensor([[-0.8996,  0.9983]], grad_fn=<SqueezeBackward1>),
 torch.Size([1, 2]))

In [123]:
# 2. Tính output của model ở lớp cuối cùng, với d_out = 1. (regression).
hidden_state = torch.tensor([[-0.8996, 0.9983]])
print("Final hidden state: \n{}".format(hidden_state))

hidden_to_output = nn.Linear(2, 1, bias=True)
hidden_to_output.weight.data = torch.tensor([[0.0, 1.0]])
hidden_to_output.bias.data = torch.tensor([0.0])

with torch.no_grad():
    output = hidden_to_output(hidden_state)
    print("Output: \n{}".format(output))

Final hidden state: 
tensor([[-0.8996,  0.9983]])
Output: 
tensor([[0.9983]])


In [98]:
hn.squeeze()

tensor([-0.8996,  0.9983], grad_fn=<SqueezeBackward0>)

In [124]:
W = torch.tensor([0.0, 1.0])
b = torch.tensor(0.0)
W @ hn.squeeze() + b

tensor(0.9983, grad_fn=<AddBackward0>)

In [None]:
# 3. Chuyển bài toán thành bài toán Sequence Labeling, tính output của từng time-steps.
# Giả sử mỗi hàng là một embedding của một từ trong câu, và mỗi từ được gán nhãn 0 hoặc 1.

hidden_to_output = nn.Linear(2, 2, bias=True)
hidden_to_output.weight.data = torch.tensor([[0.0, 1.0], [1.0, 0.0]])
hidden_to_output.bias.data = torch.tensor([0.0])

hidden_state = None
with torch.no_grad():
    for i in range(3):
        print(25 * "*")
        hidden_state = rnn(X[i].unsqueeze(0), hidden_state)
        print(
            "Output of the hidden state at time step {}: \n{}".format(i, hidden_state)
        )
        output = nn.Softmax(dim=1)(hidden_to_output(hidden_state))
        print("Output at time step {}: \n{}".format(i, output))
        print(
            "Predicted label at time step {}: \n{}".format(
                i, torch.argmax(output, dim=1)
            )
        )

*************************
Output of the hidden state at time step 0: 
tensor([[0.9866, 0.9051]])
Output at time step 0: 
tensor([[0.4796, 0.5204]])
Predicted label at time step 0: 
tensor([1])
*************************
Output of the hidden state at time step 1: 
tensor([[0.0946, 0.9702]])
Output at time step 1: 
tensor([[0.7059, 0.2941]])
Predicted label at time step 1: 
tensor([0])
*************************
Output of the hidden state at time step 2: 
tensor([[-0.8996,  0.9983]])
Output at time step 2: 
tensor([[0.8697, 0.1303]])
Predicted label at time step 2: 
tensor([0])


In [150]:
W.shape, output.shape

(torch.Size([2, 2]), torch.Size([3, 2]))

In [5]:
W = torch.tensor([[0.0, 1.0], [1.0, 0.0]])
b = torch.tensor([0.0, 0.0])
W @ output.transpose(0, 1)

tensor([[ 0.9051,  0.9702,  0.9983],
        [ 0.9866,  0.0946, -0.8996]], grad_fn=<MmBackward0>)

In [None]:
# 4. Bidirectional RNN, đưa ra output cuối cùng.
# Giá trị trọng số 2 mạng là như nhau, tính output của model ở lớp cuối cùng.
forward_rnn = RNNCell(input_size=3, hidden_size=2)
backward_rnn = RNNCell(input_size=3, hidden_size=2)

forward_rnn.set_weights(
    input_to_hidden_weight,
    hidden_to_hidden_weight,
    input_to_hidden_bias,
    hidden_to_hidden_bias,
)
backward_rnn.set_weights(
    input_to_hidden_weight,
    hidden_to_hidden_weight,
    input_to_hidden_bias,
    hidden_to_hidden_bias,
)

hidden_to_output = nn.Linear(4, 1, bias=True)
hidden_to_output.weight.data = torch.tensor([[1.0, 1.0, 1.0, 1.0]])
hidden_to_output.bias.data = torch.tensor([1.0])

hidden_state_forward = None
hidden_state_backward = None
with torch.no_grad():
    for i in range(3):
        print(25 * "*")
        hidden_state_forward = forward_rnn(X[i].unsqueeze(0), hidden_state_forward)
        hidden_state_backward = backward_rnn(
            X[2 - i].unsqueeze(0), hidden_state_backward
        )
        print(
            "Output of the hidden state at time step {}: \n{}".format(
                i, hidden_state_forward
            )
        )
        print(
            "Output of the hidden state at time step {}: \n{}".format(
                i, hidden_state_backward
            )
        )
        hidden_state = torch.cat((hidden_state_forward, hidden_state_backward), dim=1)
        output = hidden_to_output(hidden_state)
        print("Output at time step {}: \n{}".format(i, output))

*************************
Output of the hidden state at time step 0: 
tensor([[0.9866, 0.9051]])
Output of the hidden state at time step 0: 
tensor([[-0.4621,  0.9998]])
Output at time step 0: 
tensor([[3.4294]])
*************************
Output of the hidden state at time step 1: 
tensor([[0.0946, 0.9702]])
Output of the hidden state at time step 1: 
tensor([[2.4676e-04, 9.6404e-01]])
Output at time step 1: 
tensor([[3.0290]])
*************************
Output of the hidden state at time step 2: 
tensor([[-0.8996,  0.9983]])
Output of the hidden state at time step 2: 
tensor([[0.9114, 0.4899]])
Output at time step 2: 
tensor([[2.5000]])


In [8]:
hn.squeeze()

tensor([-0.8996,  0.9983], grad_fn=<SqueezeBackward0>)

In [23]:
hn_b = torch.tensor([[0.9114, 0.4899]])
hn_cat = torch.cat((hn.squeeze(), hn_b.squeeze()), dim=0)
torch.ones((1, 4)) @ hn_cat + torch.tensor(1.0)

tensor([2.5000], grad_fn=<AddBackward0>)

# Define LSTM Cell

In [None]:
class LSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size, bias=False):
        super(LSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        # Linear layers for input-to-hidden and hidden-to-hidden transformations
        self.hidden_to_forget = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.hidden_to_input = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.hidden_to_cell = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.hidden_to_output = nn.Linear(hidden_size, hidden_size, bias=bias)

        self.input_to_forget = nn.Linear(input_size, hidden_size, bias=bias)
        self.input_to_input = nn.Linear(input_size, hidden_size, bias=bias)
        self.input_to_cell = nn.Linear(input_size, hidden_size, bias=bias)
        self.input_to_output = nn.Linear(input_size, hidden_size, bias=bias)

    def set_weights(
        self,
        input_to_forget_weight,
        input_to_input_weight,
        input_to_cell_weight,
        input_to_output_weight,
        hidden_to_forget_weight,
        hidden_to_input_weight,
        hidden_to_cell_weight,
        hidden_to_output_weight,
    ):
        self.input_to_forget.weight.data = input_to_forget_weight
        self.input_to_input.weight.data = input_to_input_weight
        self.input_to_cell.weight.data = input_to_cell_weight
        self.input_to_output.weight.data = input_to_output_weight

        self.hidden_to_forget.weight.data = hidden_to_forget_weight
        self.hidden_to_input.weight.data = hidden_to_input_weight
        self.hidden_to_cell.weight.data = hidden_to_cell_weight
        self.hidden_to_output.weight.data = hidden_to_output_weight

    def forward(self, input, hidden_state_tuple=None):
        # If hidden state is not provided, initialize it to zeros (the first time step)
        if hidden_state_tuple is None:
            hidden_state_tuple = input.new_zeros(
                input.size(0), self.hidden_size, requires_grad=False
            )
            hidden_state_tuple = (hidden_state_tuple, hidden_state_tuple)

        hidden_state, cell_state_prev = hidden_state_tuple

        # Compute gates
        input_gate = self.input_to_input(input) + self.hidden_to_input(
            hidden_state
        )  # Add gate
        forget_gate = self.input_to_forget(input) + self.hidden_to_forget(
            hidden_state
        )  # Forget gate
        cell_gate = self.input_to_cell(input) + self.hidden_to_cell(hidden_state)
        output_gate = self.input_to_output(input) + self.hidden_to_output(
            hidden_state
        )  # Output gate

        print("Input gate: \n{}".format(input_gate))
        print("Forget gate: \n{}".format(forget_gate))
        print("Cell gate: \n{}".format(cell_gate))

        # Apply nonlinearities
        i_t = torch.sigmoid(input_gate)
        f_t = torch.sigmoid(forget_gate)
        g_t = torch.tanh(cell_gate)
        o_t = torch.sigmoid(output_gate)

        # Update cell state and hidden state
        # k_t = f_t * cell_state_prev
        # j_t = i_t * g_t
        cell_state_next = f_t * cell_state_prev + i_t * g_t
        hidden_state_next = o_t * torch.tanh(cell_state_next)

        return hidden_state_next, cell_state_next

In [None]:
lstm = LSTMCell(input_size=3, hidden_size=2)
classification = nn.Linear(2, 1, bias=True)

input_to_forget_weight = torch.tensor([[1.0, 0.0, 0.5], [1.0, 0.0, -0.5]])
input_to_input_weight = torch.tensor([[-1.0, 0.0, 0.5], [-1.0, 0.0, -0.5]])
input_to_cell_weight = torch.tensor([[1.0, 0.0, -0.5], [1.0, 0.0, -0.5]])
input_to_output_weight = torch.tensor([[1.0, 1.0, 0.5], [1.0, -1.0, -0.5]])

hidden_to_forget_weight = torch.tensor([[0.0, 1.0], [0.0, -1.0]])
hidden_to_input_weight = torch.tensor([[0.0, 1.0], [0.0, -1.0]])
hidden_to_cell_weight = torch.tensor([[0.0, 1.0], [0.0, -1.0]])
hidden_to_output_weight = torch.tensor([[0.0, 1.0], [0.0, -1.0]])

lstm.set_weights(
    input_to_forget_weight,
    input_to_input_weight,
    input_to_cell_weight,
    input_to_output_weight,
    hidden_to_forget_weight,
    hidden_to_input_weight,
    hidden_to_cell_weight,
    hidden_to_output_weight,
)

hidden_state_tuple = None
with torch.no_grad():
    print(25 * "#")
    hidden_state_tuple = lstm(X[0].unsqueeze(0), hidden_state_tuple)
    print(
        "Output of the hidden state at time step 0: \n{}".format(hidden_state_tuple[0])
    )

#########################
Input gate: 
tensor([[ 0.5000, -2.5000]])
Forget gate: 
tensor([[ 2.5000, -0.5000]])
Cell gate: 
tensor([[-0.5000, -0.5000]])
Output of the hidden state at time step 0: 
tensor([[-0.2769, -0.0027]])


In [36]:
import torch

lstm = torch.nn.LSTM(input_size=3, hidden_size=2, bias=False)

for k, v in lstm.state_dict().items():
    print(k, v.shape)

weight_ih_l0 torch.Size([8, 3])
weight_hh_l0 torch.Size([8, 2])


In [56]:
Wxf = torch.tensor([[1.0, 0.0, 0.5], [1.0, 0.0, -0.5]])
Whf = torch.tensor([[0.0, 1.0], [0.0, -1.0]])
Wxi = torch.tensor([[-1.0, 0.0, 0.5], [-1.0, 0.0, -0.5]])
Whi = torch.tensor([[0.0, 1.0], [0.0, -1.0]])
Wxc = torch.tensor([[1.0, 0.0, -0.5], [1.0, 0.0, -0.5]])
Whc = torch.tensor([[0.0, 1.0], [0.0, -1.0]])
Wxo = torch.tensor([[1.0, 1.0, 0.5], [1.0, -1.0, -0.5]])
Who = torch.tensor([[0.0, 1.0], [0.0, -1.0]])

C_t_minus_1 = torch.tensor([[0.0, 0.0]])
x_t = torch.tensor([[1.0, 2.0, 3.0]])

In [57]:
(f_t := x_t @ Wxf.T + C_t_minus_1 @ Whf.T)
(i_t := x_t @ Wxi.T + C_t_minus_1 @ Whi.T)
(g_t := x_t @ Wxc.T + C_t_minus_1 @ Whc.T)
(o_t := x_t @ Wxo.T + C_t_minus_1 @ Who.T)

tensor([[ 4.5000, -2.5000]])

In [58]:
f_t, i_t, g_t, o_t

(tensor([[ 2.5000, -0.5000]]),
 tensor([[ 0.5000, -2.5000]]),
 tensor([[-0.5000, -0.5000]]),
 tensor([[ 4.5000, -2.5000]]))

In [59]:
f_t = torch.sigmoid(f_t)
i_t = torch.sigmoid(i_t)
g_t = torch.tanh(g_t)
o_t = torch.sigmoid(o_t)
c_t = f_t * C_t_minus_1 + i_t * g_t
h_t = o_t * torch.tanh(c_t)

In [None]:
h_t

tensor([[-0.2769, -0.0027]])

: 