In [1]:
import string

import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Character vocabulary: a space, then ASCII punctuation, digits and letters.
# A character's position in `abc` is its integer code (space -> 0).
abc = ''.join((' ', string.punctuation, string.digits, string.ascii_letters))
len_abc = len(abc)

# Lookup tables in both directions: index -> character and character -> index.
itoc = dict(enumerate(abc))
ctoi = {ch: idx for idx, ch in itoc.items()}

print(len_abc, abc)

95  !"#$%&'()*+,-./:;<=>?@[\]^_`{|}~0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ


In [4]:
def encode(s: str) -> list[int]:
    """Translate a string into the list of its per-character vocabulary indices.

    Each character is looked up in the module-level ``ctoi`` table; a
    character missing from that table raises ``KeyError``.
    """
    indices = []
    for ch in s:
        indices.append(ctoi[ch])
    return indices

def decode(l: list[int]) -> str:
    """Turn a list of vocabulary indices back into the string they spell.

    Each index is looked up in the module-level ``itoc`` table; an index
    missing from that table raises ``KeyError``.
    """
    chars = (itoc[idx] for idx in l)
    return ''.join(chars)

def enc2tnsr(l: list[int]) -> torch.Tensor:
    """Wrap a list of indices in a 1-D ``torch.int64`` tensor.

    The explicit cast matters for the empty list, which ``torch.tensor``
    would otherwise create as a float tensor.
    """
    indices = torch.tensor(l)
    return indices.to(torch.int64)

def enc2seq(l: list[int]) -> torch.Tensor:
    """One-hot encode a list of vocabulary indices.

    The list is converted with ``enc2tnsr`` and expanded by ``F.one_hot`` to
    ``len_abc`` classes, so a list of length ``n`` yields a float tensor of
    shape ``(n, len_abc)``.
    """
    indices = enc2tnsr(l)
    one_hot = F.one_hot(indices, num_classes=len_abc)
    return one_hot.float()

def enct2seq(t: torch.Tensor) -> torch.Tensor:
    """One-hot encode an index tensor over the ``len_abc``-symbol vocabulary.

    ``t`` is passed straight to ``F.one_hot`` (which expects an int64 index
    tensor); the result gains a trailing dimension of size ``len_abc`` and is
    returned as float.
    """
    one_hot = F.one_hot(t, num_classes=len_abc)
    return one_hot.float()

def str2seq(s: str) -> torch.Tensor:
    """One-hot encode a string, one row per character.

    The string is mapped to indices with ``encode``, cast to int64 and
    expanded by ``F.one_hot`` to ``len_abc`` classes, giving a float tensor of
    shape ``(len(s), len_abc)``.
    """
    indices = torch.tensor(encode(s)).to(torch.int64)
    one_hot = F.one_hot(indices, num_classes=len_abc)
    return one_hot.float()

In [85]:
class AttentiveRNN(nn.Module):
    """RNN encoder with a small attention-read, gated-write ring-buffer memory.

    Every ``forward`` call encodes one *unbatched* sequence with ``self.rnn``,
    reads a context vector from the memory ``S`` by dot-product attention,
    mixes the final hidden state with that context through two sigmoid gates,
    and writes the mixed state into the next memory slot.

    Args:
        in_dim: Number of input features per time step.
        h_dim: Hidden size of the RNN and width of each memory row.
        window_size: Number of rows in the memory.

    Attributes:
        S: Buffer of shape ``(window_size, h_dim)`` holding the stored states;
            initialised to zeros.
        idx: Row of ``S`` that the next ``forward`` call overwrites. It is a
            plain Python int, so it is not saved in ``state_dict()``.

    Note:
        The memory stays connected to the autograd graph, so gradients flow
        through it from one call to the next. Call ``reset_memory()`` (or run
        under ``torch.no_grad()``) at sequence boundaries to cut the graph.
    """

    def __init__(self, in_dim, h_dim, window_size) -> None:
        super().__init__()

        self.idx = 0
        self.window_size = window_size
        self.register_buffer('S', torch.zeros(window_size, h_dim))

        # batch_first has no effect on the unbatched input forward() accepts;
        # it is kept so the submodule configuration stays as before.
        self.rnn = nn.RNN(in_dim, h_dim, batch_first=True)
        # Weight of the bilinear score h^T W c that drives the scalar gate.
        self.W = nn.Parameter(torch.rand(h_dim, h_dim))

    def reset_memory(self) -> None:
        """Zero the memory, rewind the write position and drop its graph."""
        self.S = torch.zeros_like(self.S)
        self.idx = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode one sequence, update the memory and return the stored state.

        Args:
            x: Unbatched sequence of shape ``(seq_len, in_dim)``.

        Returns:
            The state written into the memory, shape ``(h_dim,)``.

        Raises:
            ValueError: If ``x`` is not 2-D. The memory is shared and has no
                batch dimension, so batched input is not supported.
        """
        if x.dim() != 2:
            raise ValueError(
                "AttentiveRNN.forward expects an unbatched (seq_len, in_dim) "
                f"input, got shape {tuple(x.shape)}"
            )

        _, h_n = self.rnn(x)
        h = h_n[-1]                            # final hidden state, (h_dim,)

        memory = self.S
        attn = F.softmax(memory @ h, dim=0)    # weight per memory row, (window,)
        c = attn @ memory                      # attention read-out, (h_dim,)

        a = torch.sigmoid(h @ self.W @ c)      # scalar gate from bilinear score
        g = torch.sigmoid(h * c)               # element-wise gate, (h_dim,)
        s = a * (1 - g) * h + (1 - a) * g * c

        # Write out of place: `memory` was saved by autograd for the matmuls
        # above, and modifying it in place would make backward() fail.
        updated = memory.clone()
        updated[self.idx] = s
        self.S = updated
        self.idx = (self.idx + 1) % self.window_size
        return s

In [86]:
# Smoke-test fixtures: a module with 10 input features, hidden size 5 and a
# 3-row memory, followed by five random 2-D inputs of shape (4, 10).
arnn = AttentiveRNN(in_dim=10, h_dim=5, window_size=3)
X = [
    torch.rand(4, 10)
    for _ in range(5)
]

In [87]:
# Feed the inputs to the module one at a time, in list order. Each call
# overwrites one row of arnn.S and advances arnn.idx, so the order matters;
# with five inputs and a three-row memory the write position wraps around.
# The value returned by each call is discarded.
for t in X:
    arnn(t)