In [4]:
import math
import string

import torch
import torch.nn as nn
import torch.nn.functional as F

In [22]:
class MyRNN(nn.Module):
    """Minimal single-layer recurrent network with a tanh non-linearity.

    Every time step computes
    ``h_t = tanh(x_t @ W_ih + b_ih + h_{t-1} @ W_hh + b_hh)``,
    starting from an all-zero hidden state.

    Args:
        n_in: Size of each input vector.
        n_h: Size of the hidden state.
    """

    def __init__(self, n_in: int, n_h: int) -> None:
        super().__init__()

        # Bounds of the uniform initialisation: U(-1/sqrt(fan_in), 1/sqrt(fan_in)).
        a, b = 1/math.sqrt(n_in), 1/math.sqrt(n_h)

        # nn.Parameter registers the tensors with the module, so they show up
        # in parameters()/state_dict(), receive gradients, and follow
        # .to()/.double()/.cuda(). Plain tensor attributes do none of that.
        self.W_ih: nn.Parameter = nn.Parameter(-a + 2*a * torch.rand(n_in, n_h))
        self.b_ih: nn.Parameter = nn.Parameter(torch.zeros(n_h))

        self.W_hh: nn.Parameter = nn.Parameter(-b + 2*b * torch.rand(n_h, n_h))
        self.b_hh: nn.Parameter = nn.Parameter(torch.zeros(n_h))

    def forward(self, sequence: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the recurrence over a sequence or a batch of sequences.

        Args:
            sequence: Tensor of shape ``(seq_len, n_in)`` for a single
                sequence or ``(batch, seq_len, n_in)`` for a batch.

        Returns:
            A pair ``(states, context)``: ``states`` has shape
            ``(batch, seq_len, n_h)`` and holds the hidden state after every
            step; ``context`` has shape ``(batch, n_h)`` and is the final
            hidden state (all zeros when ``seq_len`` is 0).
        """
        if sequence.ndim == 2:
            # A single sequence is treated as a batch of size one.
            sequence = sequence.unsqueeze(0)
        # Time-major layout so that iterating yields one (batch, n_in) slice per step.
        sequence = sequence.transpose(0, 1)

        batch_size = sequence.shape[1]
        n_h = self.W_hh.shape[0]

        # Create the initial state alongside the weights so the module keeps
        # working after it has been moved to another device or dtype.
        context = torch.zeros(
            batch_size, n_h, dtype=self.W_hh.dtype, device=self.W_hh.device
        )

        states = []
        for x in sequence:
            context = torch.tanh((x @ self.W_ih) + self.b_ih + (context @ self.W_hh) + self.b_hh)
            states.append(context)

        if states:
            stacked = torch.stack(states)
        else:
            # torch.stack rejects an empty list; return an empty time axis instead.
            stacked = context.new_zeros(0, batch_size, n_h)

        return (stacked.transpose(0, 1), context)

In [23]:
# Vocabulary: space and punctuation first, followed by the ASCII letters.
abc = ' .,:;-\"\'' + string.ascii_letters
len_abc = len(abc)

# Lookup tables: index -> character, and its inverse character -> index.
itoc = dict(enumerate(abc))
ctoi = {ch: idx for idx, ch in itoc.items()}

In [28]:
def encode(s: str) -> list[int]:
    """Translate a string into the list of its per-character ``ctoi`` indices.

    Raises:
        KeyError: If ``s`` contains a character that is not a key of ``ctoi``.
    """
    indices: list[int] = []
    for ch in s:
        indices.append(ctoi[ch])
    return indices

def decode(l: list[int]) -> str:
    """Turn a list of indices back into a string via the ``itoc`` table.

    Raises:
        KeyError: If an index is not a key of ``itoc``.
    """
    chars = (itoc[idx] for idx in l)
    return ''.join(chars)

def str2seq(s: str) -> torch.Tensor:
    """One-hot encode a string.

    Returns:
        A float tensor of shape ``(len(s), len_abc)`` with a single 1.0 per
        row, placed at the index that ``encode`` assigns to that character.
    """
    indices = torch.tensor(encode(s), dtype=torch.long)
    one_hot = F.one_hot(indices, num_classes=len_abc)
    return one_hot.float()

In [29]:
# One input feature per vocabulary symbol (one-hot), four hidden units.
rnn = MyRNN(n_in=len_abc, n_h=4)

In [30]:
text = 'the quick brown fox jumps over the lazy dog'

# str2seq yields a 2-D (seq_len, len_abc) tensor; the forward pass adds the
# batch axis itself when it sees a 2-D input.
states, context = rnn(str2seq(text))
print(states.size(), context.size())

torch.Size([1, 43, 4]) torch.Size([1, 4])


In [31]:
text1 = 'the quick brown fox jumps over the lazy dog'
text2 = 'god yzal eht revo spmuj xof nworb kciuq eht'
# Both strings have the same length, so their encodings stack into one batch.
batch = torch.stack([str2seq(t) for t in (text1, text2)])

states, context = rnn(batch)
print(states.size(), context.size())

torch.Size([2, 43, 4]) torch.Size([2, 4])
