In [1]:
import os
# Run relative to the repository root. Set ML_LEARNING_ROOT to override the author's
# local default so the notebook also runs on other machines.
os.chdir(os.environ.get("ML_LEARNING_ROOT", "/Users/yenchenchou/Documents/GitHub/ml-learning"))

In [17]:
import torch
import torch.nn as nn
import numpy as np
import random

In [18]:
class EnvInit:
    """Small helpers for preparing a reproducible PyTorch session."""

    def available_device(self) -> torch.device:
        """
        Return the preferred torch device.

        Apple MPS is tried first, then CUDA; CPU is the fallback.
        """
        if torch.backends.mps.is_available():
            return torch.device("mps")
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    def fix_seed(self, seed: int) -> int:
        """
        Seed Python's `random`, NumPy's global RNG and torch, plus CUDA/MPS when present.

        Args:
            seed: the seed applied to every generator.

        Returns:
            The same `seed` value, unchanged.
        """
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)
        if torch.cuda.is_available():
            # Seed the current GPU and then every visible GPU.
            for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
                cuda_seeder(seed)
        if torch.backends.mps.is_available():
            torch.mps.manual_seed(seed)
        return seed

In [96]:
import torch
import torch.nn as nn

class MytorchLSTMV1(nn.Module):
    """
    Custom single-layer, unidirectional LSTM whose forward call mirrors `nn.LSTM`:
    `output, (h_n, c_n) = lstm(X, (h_0, c_0))`.

    At each time step the layer computes (`@` = matrix product, `*` = elementwise):

    (1) Forget gate: decides what to drop from the previous cell state.
        ft = sigmoid(X @ Wif + bif + Ht-1 @ Whf + bhf)
        ft * ct-1 -> actual forgetting

    (2) Input gate and candidate: decide what new information to write.
        it = sigmoid(X @ Wii + bii + Ht-1 @ Whi + bhi)
        ct_candidate = tanh(X @ Wic + bic + Ht-1 @ Whc + bhc)
        it * ct_candidate -> actual adding

        ct = ft * ct-1 + it * ct_candidate -> new cell state

    (3) Output gate: decides which part of the squashed cell state becomes the hidden state.
        ot = sigmoid(X @ Wio + bio + Ht-1 @ Who + bho)
        ht = ot * tanh(ct)

    Notation:
        - X: current input; Ht-1: previous hidden state; ct-1: previous cell state
        - Wif, Wii, Wic, Wio: input-to-gate weights (`self.wif`, `self.wii`, ...)
        - Whf, Whi, Whc, Who: hidden-to-gate weights (`self.whf`, `self.whi`, ...)
        - bif, bii, bic, bio / bhf, bhi, bhc, bho: the matching biases; each gate has
          both an input-side and a hidden-side bias when `bias=True`, none otherwise

    Args:
        input_size: number of features per time step.
        hidden_size: size of the hidden and cell states.
        batch_first: if True, `X` is (batch, seq, input_size) and the output is
            (batch, seq, hidden_size); otherwise both are sequence-first.
        bias: whether the gate projections carry bias terms.
    """

    def __init__(
        self, input_size: int, hidden_size: int, batch_first: bool = False, bias: bool = True
    ):
        super().__init__()
        self.batch_first = batch_first
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Forget gate
        self.wif = nn.Linear(input_size, hidden_size, bias=bias)
        self.whf = nn.Linear(hidden_size, hidden_size, bias=bias)

        # Input gate
        self.wii = nn.Linear(input_size, hidden_size, bias=bias)
        self.whi = nn.Linear(hidden_size, hidden_size, bias=bias)

        # Candidate (cell) gate
        self.wic = nn.Linear(input_size, hidden_size, bias=bias)
        self.whc = nn.Linear(hidden_size, hidden_size, bias=bias)

        # Output gate
        self.wio = nn.Linear(input_size, hidden_size, bias=bias)
        self.who = nn.Linear(hidden_size, hidden_size, bias=bias)

    @staticmethod
    def _strip_layer_dim(state: torch.Tensor) -> torch.Tensor:
        """Return `state` as (batch, hidden), dropping a leading num_layers=1 dim if present."""
        if state.dim() == 3:
            if state.size(0) != 1:
                raise ValueError(
                    "expected a single-layer state of shape (1, batch, hidden), "
                    f"got {tuple(state.shape)}"
                )
            return state.squeeze(0)
        return state

    def forward(
        self,
        X: torch.Tensor,
        hidden_tuple: "tuple[torch.Tensor, torch.Tensor] | None" = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """
        Run the LSTM over a whole sequence.

        Args:
            X: (seq, batch, input_size), or (batch, seq, input_size) when batch_first.
            hidden_tuple: initial (h_0, c_0). Each is either (1, batch, hidden_size),
                the `nn.LSTM` layout, or (batch, hidden_size). When omitted, both
                default to zeros of shape (1, batch, hidden_size).

        Returns:
            outputs: hidden state at every step, (seq, batch, hidden_size), or
                (batch, seq, hidden_size) when batch_first.
            (h_n, c_n): final hidden and cell states, in the same layout as h_0
                ((1, batch, hidden_size) when hidden_tuple is omitted).

        Raises:
            ValueError: if the sequence is empty or a 3-D state has a leading dim != 1.
        """
        if self.batch_first:
            X = X.transpose(0, 1)  # iterate sequence-first internally

        seq_len, batch = X.size(0), X.size(1)
        if seq_len == 0:
            raise ValueError("MytorchLSTMV1 received an empty sequence")

        if hidden_tuple is None:
            zeros = X.new_zeros(1, batch, self.hidden_size)
            hidden_tuple = (zeros, zeros)
        h0, c0 = hidden_tuple
        # Remember whether the caller used nn.LSTM's (1, batch, hidden) layout so the
        # final states can be returned in that same layout.
        keep_layer_dim = h0.dim() == 3
        h_prev = self._strip_layer_dim(h0)
        c_prev = self._strip_layer_dim(c0)

        outputs = []
        for t in range(seq_len):
            x_t = X[t]
            ft = torch.sigmoid(self.wif(x_t) + self.whf(h_prev))
            it = torch.sigmoid(self.wii(x_t) + self.whi(h_prev))
            candidate = torch.tanh(self.wic(x_t) + self.whc(h_prev))
            c_prev = ft * c_prev + it * candidate
            # The output gate reads the *previous* hidden state, so h_prev is updated last.
            ot = torch.sigmoid(self.wio(x_t) + self.who(h_prev))
            h_prev = ot * torch.tanh(c_prev)
            outputs.append(h_prev)

        # stack adds the time axis -> (seq, batch, hidden) whatever the state layout;
        # concatenating 2-D states would instead give (seq * batch, hidden).
        outputs = torch.stack(outputs)
        if self.batch_first:
            outputs = outputs.transpose(0, 1)

        if keep_layer_dim:
            h_prev, c_prev = h_prev.unsqueeze(0), c_prev.unsqueeze(0)
        return outputs, (h_prev, c_prev)


In [103]:
# Seed every RNG first so the random inputs and the weight initialisation are
# reproducible under Restart & Run All.
SEED = 0
EnvInit().fix_seed(SEED)

input_size = 3
hidden_size = 6
batch_size = 2
seq_length = 8
lstm = MytorchLSTMV1(input_size, hidden_size, batch_first=True, bias=True)
x = torch.randn(batch_size, seq_length, input_size)  # (batch, seq, input) because batch_first=True
# Initial states follow nn.LSTM's layout: (num_layers * num_directions, batch, hidden) = (1, batch, hidden).
h0 = torch.randn(1, batch_size, hidden_size)
c0 = torch.randn(1, batch_size, hidden_size)
output, (ht, ct) = lstm(x, (h0, c0))

# Sanity-check the shapes displayed by the following cells.
assert output.shape == (batch_size, seq_length, hidden_size)
assert ht.shape == ct.shape == (1, batch_size, hidden_size)
assert torch.allclose(output[:, -1], ht[0])  # last output step is the final hidden state

In [104]:
# Per-step hidden states of the custom LSTM together with their shape.
output, output.shape

(tensor([[[-0.0486, -0.4845, -0.1612,  0.0953, -0.4194, -0.4590],
          [ 0.0302, -0.2870, -0.0870,  0.0015, -0.4005, -0.0403],
          [ 0.1063, -0.1486, -0.0224, -0.0485, -0.2485,  0.0127],
          [ 0.1612, -0.1206, -0.2458, -0.0148, -0.2493,  0.1344],
          [ 0.1405, -0.2365, -0.1662, -0.1215,  0.0457,  0.0251],
          [ 0.2077, -0.1192, -0.1891, -0.0703, -0.0832,  0.0990],
          [ 0.2366, -0.0976,  0.0934, -0.2086,  0.0132,  0.0281],
          [ 0.1891, -0.1331,  0.1026, -0.1295,  0.0250,  0.0550]],
 
         [[-0.3546, -0.4126, -0.2590, -0.2608,  0.0372, -0.3793],
          [-0.0580, -0.0429, -0.3364, -0.0255,  0.0847,  0.0588],
          [ 0.0006, -0.0506, -0.2623, -0.0594, -0.0038,  0.1337],
          [ 0.0991, -0.1199, -0.0696, -0.1217,  0.0202,  0.1158],
          [ 0.1514, -0.0383, -0.1470, -0.0240, -0.2287,  0.1310],
          [ 0.2270, -0.1361, -0.0942, -0.0718,  0.0182,  0.1390],
          [ 0.1693, -0.1397, -0.2141, -0.0412, -0.0319,  0.1919],
       

In [105]:
# Final hidden state returned by the custom LSTM, with its shape.
ht, ht.shape

(tensor([[[ 0.1891, -0.1331,  0.1026, -0.1295,  0.0250,  0.0550],
          [ 0.1453, -0.1554, -0.2053, -0.0744,  0.0076,  0.1989]]],
        grad_fn=<MulBackward0>),
 torch.Size([1, 2, 6]))

In [106]:
# Final cell state returned by the custom LSTM, with its shape.
ct, ct.shape

(tensor([[[ 0.6705, -0.2353,  0.1591, -0.2561,  0.0606,  0.1123],
          [ 0.5622, -0.2613, -0.3081, -0.1608,  0.0177,  0.3928]]],
        grad_fn=<AddBackward0>),
 torch.Size([1, 2, 6]))

In [77]:
# Reference: PyTorch's built-in nn.LSTM, used to compare output/state shapes with the custom cell.
# Distinct input/state names avoid shadowing the builtin `input` and clobbering the h0/c0
# used by the custom-LSTM cell.
rnn = nn.LSTM(input_size=10, hidden_size=20, num_layers=1, batch_first=True)
ref_input = torch.randn(3, 5, 10)  # (batch=3, seq=5, input=10) since batch_first=True
ref_h0 = torch.randn(1, 3, 20)  # (num_layers * num_directions, batch, hidden)
ref_c0 = torch.randn(1, 3, 20)
output, (hn, cn) = rnn(ref_input, (ref_h0, ref_c0))

In [107]:
# Shape of `output` as bound by the most recently executed LSTM cell.
output.shape

torch.Size([2, 8, 6])

In [108]:
# Shape of the reference nn.LSTM's final hidden state.
hn.shape

torch.Size([1, 3, 20])

In [109]:
# Shape of the reference nn.LSTM's final cell state.
cn.shape

torch.Size([1, 3, 20])