In [89]:
import os
os.chdir("/Users/yenchenchou/Documents/GitHub/ml-learning")

In [90]:
import torch
import torch.nn as nn
import numpy as np
import random

In [91]:
class EnvInit:
    def available_device(self) -> torch.device:
        if torch.backends.mps.is_available():
            device = torch.device("mps")
        elif torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            device = torch.device("cpu")
        return device

    def fix_seed(self, seed: int) -> int:
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
        if torch.backends.mps.is_available():
            torch.mps.manual_seed(seed)
        return seed

In [92]:
class MyRNNCell(nn.Module):
    """
    Custom RNN implementation that follows PyTorch's API.

    (1) ht = tanh(Ht-1 @ Whh + bhh + X @ Wih + bih)

    The implementation uses the following notation:
        - Ht-1: Previous hidden state
        - X: Current input
        - Whh: Weight for hidden state
        - Wih: Weight for input state
    """

    def __init__(
        self, input_size: int, hidden_size: int, batch_first: bool, bias: bool
    ):
        super(MyRNNCell, self).__init__()
        self.batch_first = batch_first
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.w_ih = nn.Linear(input_size, hidden_size, bias)
        self.w_hh = nn.Linear(hidden_size, hidden_size, bias)

    def forward(self, X: torch.tensor, ho=None):
        if self.batch_first:
            X = X.transpose(0, 1)
        if ho is None:
            batch_size = X.size(1)
            ho = torch.zeros(batch_size, self.hidden_size).unsqueeze(0)
        h_t = self.w_hh(ho)
        output = torch.tanh(self.w_ih(X) + h_t)
        if self.batch_first:
            output = output.transpose(0, 1)
        return output, h_t

In [93]:
env_init = EnvInit()
seed = env_init.fix_seed(12345)

# Initialize models
x = torch.ones(2, 3, 2)  # batch, seq_length, input_size
h_0 = torch.zeros(2, 4).unsqueeze(0)  # batch_size, hidden_size -> 1(one_directional), batch_size, hidden_size
input_size = x.size(-1)
hidden_size = h_0.size(-1)
rnn_v1 = MyTorchRNN(input_size, hidden_size, batch_first=True)
output, ht = rnn_v1(x, h_0)

In [97]:
output, output.size()

(tensor([[[ 0.8391,  0.5655, -0.7795, -0.0916],
          [ 0.8391,  0.5655, -0.7795, -0.0916],
          [ 0.8391,  0.5655, -0.7795, -0.0916]],
 
         [[ 0.8391,  0.5655, -0.7795, -0.0916],
          [ 0.8391,  0.5655, -0.7795, -0.0916],
          [ 0.8391,  0.5655, -0.7795, -0.0916]]], grad_fn=<TransposeBackward0>),
 torch.Size([2, 3, 4]))