In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import string

In [None]:
# Vocabulary: space, ASCII punctuation, digits, then lower- and upper-case letters.
abc = ''.join([' ', string.punctuation, string.digits, string.ascii_letters])
len_abc = len(abc)
print(len_abc, abc)

# Bidirectional lookup tables between vocabulary index and character.
itoc = dict(enumerate(abc))
ctoi = {c: i for i, c in itoc.items()}

def encode(s: str) -> list[int]:
    """Map each character of ``s`` to its vocabulary index.

    Raises:
        KeyError: if ``s`` contains a character missing from ``ctoi``.
    """
    return list(map(ctoi.__getitem__, s))

def decode(l: list[int]) -> str:
    """Inverse of ``encode``: concatenate the characters for the indices in ``l``.

    Raises:
        KeyError: if an index is missing from ``itoc``.
    """
    return ''.join(map(itoc.__getitem__, l))

def enc2tnsr(l: list[int]) -> torch.Tensor:
    """Convert a list of indices into an int64 tensor (empty list -> empty int64 tensor)."""
    indices = torch.tensor(l)
    return indices.to(torch.long)

def enc2seq(l: list[int]) -> torch.Tensor:
    """One-hot encode an index list as float vectors of size ``len_abc``.

    For a flat list the result has shape ``(len(l), len_abc)``.
    """
    indices = enc2tnsr(l)
    one_hot = F.one_hot(indices, num_classes=len_abc)
    return one_hot.float()

def enct2seq(t: torch.Tensor) -> torch.Tensor:
    """One-hot encode an integer index tensor, appending a float dimension of size ``len_abc``."""
    return F.one_hot(t, num_classes=len_abc).to(torch.float32)

def str2seq(s: str) -> torch.Tensor:
    """One-hot encode a string as a ``(len(s), len_abc)`` float tensor.

    Delegates to ``encode`` + ``enc2seq`` so the string and index-list paths
    share a single encoding pipeline instead of duplicating it.

    Raises:
        KeyError: if ``s`` contains a character outside the vocabulary.
    """
    return enc2seq(encode(s))

In [None]:
class TransformerBlock(nn.Module):
    def __init__(self, emb_dim, n_heads, h_dim) -> None:
        super(TransformerBlock, self).__init__()

        self.mha = nn.MultiheadAttention(emb_dim, n_heads, batch_first=False)
        self.mlp = nn.Sequential(nn.Linear(emb_dim, h_dim), nn.ReLU(), nn.Linear(h_dim, emb_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attn_out, _ = self.mha(x, x, x)
        return self.mlp(x + attn_out)


class RNNNet(nn.Module):
    """Bank of parallel single-layer RNNs whose final hidden states are mixed
    by a ``TransformerBlock`` after every chunk of input.

    Every RNN reads the same one-hot chunk.  Afterwards the ``n_rnn`` final
    hidden states are treated as a length-``n_rnn`` sequence, passed through
    the transformer block (with an outer residual) and stored in the
    ``states`` buffer, which seeds the next chunk.

    Args:
        n_rnn: number of parallel RNNs.
        in_dim: one-hot input size of each RNN.
        h_dim: hidden size of each RNN; also the attention embedding size.
        out_dim: size of the classifier's probability vector.  ``generate``
            feeds sampled indices back as one-hot inputs, so it requires
            ``out_dim <= in_dim``.
        n_heads: attention heads in the transformer block (must divide ``h_dim``).
        th_dim: hidden width of the transformer block's MLP.

    ``states`` layout: ``(n_rnn, 1, h_dim)`` for unbatched input and
    ``(n_rnn, 1, batch, h_dim)`` for batched input; ``states[k]`` is the
    ``h_0`` given to ``self.rnns[k]``.
    """

    def __init__(self, n_rnn, in_dim, h_dim, out_dim, n_heads, th_dim) -> None:
        super().__init__()

        self.n_rnn = n_rnn
        self.in_dim = in_dim
        self.h_dim = h_dim

        self.register_buffer('states', torch.zeros(n_rnn, 1, h_dim))
        self.rnns = nn.ModuleList([nn.RNN(in_dim, h_dim, batch_first=True) for _ in range(n_rnn)])
        self.tblock = TransformerBlock(h_dim, n_heads, th_dim)
        # Normalise over the class axis (last).  dim=0 would normalise over the
        # size-1 leading axis of a (1, out_dim) input and return all ones.
        self.classifier = nn.Sequential(nn.Linear(n_rnn * h_dim, out_dim), nn.Softmax(dim=-1))

    def _ensure_states(self, shape) -> None:
        """Replace ``states`` with zeros (same device/dtype) if its shape differs from ``shape``."""
        if tuple(self.states.shape) != tuple(shape):
            self.states = self.states.new_zeros(shape)

    def reset_states(self) -> None:
        """Zero the recurrent state in the unbatched layout, keeping device and dtype."""
        self.states = self.states.new_zeros((self.n_rnn, 1, self.h_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Consume one chunk and update ``states``.

        Args:
            x: ``(seq_len, in_dim)`` unbatched or ``(batch, seq_len, in_dim)``
                batched float input.

        Returns:
            The updated ``states`` tensor.

        State is carried across calls.  It is re-zeroed only when the layout
        changes (unbatched <-> batched, or a different batch size).
        """
        if x.ndim == 3:
            expected = (self.n_rnn, 1, x.shape[0], self.h_dim)
        else:
            expected = (self.n_rnn, 1, self.h_dim)
        self._ensure_states(expected)

        # Each h_n is (1, h_dim) unbatched or (1, batch, h_dim) batched, so
        # c is (n_rnn, [batch,] h_dim): a length-n_rnn sequence for the
        # batch_first=False attention.
        c = torch.cat([rnn(x, s)[1] for rnn, s in zip(self.rnns, self.states)], dim=0)
        # Re-insert the num_layers axis at dim 1 so every states[k] is again a
        # valid h_0 (unsqueeze(-2) would misplace it for batched input).
        # NOTE(review): states keep the autograd graph across calls; detach
        # them between chunks if training with truncated BPTT.
        self.states = (c + self.tblock(c)).unsqueeze(1)
        return self.states

    @torch.no_grad()
    def generate(self, t: int, block_size: int) -> list[int]:
        """Autoregressively sample ``t`` indices, starting from index 0.

        Each step advances every RNN by one input and keeps the new hidden
        states.  Every ``block_size`` steps the states are mixed by the
        transformer block, mirroring what ``forward`` does after a chunk.
        Batched states, if present, are first replaced by unbatched zeros.

        Args:
            t: number of indices to sample.
            block_size: steps between transformer mixing steps.

        Returns:
            ``t + 1`` ints: the start index 0 followed by the samples.
        """
        self._ensure_states((self.n_rnn, 1, self.h_dim))
        result = [0]
        i, j = 0, 0

        for _ in range(t):
            if j == block_size:
                self.states = self.states + self.tblock(self.states)
                j = 0

            x = F.one_hot(torch.tensor([i], device=self.states.device), self.in_dim).to(self.states.dtype)
            # Advance each RNN one step and keep its new hidden state, so the
            # next step conditions on the whole generated prefix.
            hs = [rnn(x, s)[1] for rnn, s in zip(self.rnns, self.states)]  # each (1, h_dim)
            self.states = torch.stack(hs)                                 # (n_rnn, 1, h_dim)
            p = self.classifier(torch.cat(hs, dim=-1))                     # (1, out_dim)
            j += 1

            # Store a plain int: it is both fed back as the next input and returned.
            i = int(torch.multinomial(p, 1, True).item())
            result.append(i)

        return result

In [None]:
# Model hyper-parameters.  `generate` feeds sampled output indices back in as
# one-hot inputs, so OUT_DIM and IN_DIM should both stay equal to len_abc.
SEED = 0
IN_DIM = len_abc    # one-hot input size
N_RNN = 4           # number of parallel RNNs
H_DIM = 16          # RNN hidden size / attention embedding size
N_HEADS = 4         # attention heads (must divide H_DIM)
TH_DIM = 32         # transformer-block MLP hidden width
OUT_DIM = len_abc   # classifier output size

# Seed before construction so a fresh Run All reproduces the weight init and
# the text sampled later in the notebook.
torch.manual_seed(SEED)
model = RNNNet(N_RNN, IN_DIM, H_DIM, OUT_DIM, N_HEADS, TH_DIM)

In [None]:
# UNBATCHED INPUT: feed the sentence word by word; the model's recurrent state
# is carried from one word to the next.  Spaces only separate words and are
# never fed to the model.

txt = 'The quick brown fox jumps over the lazy dog.'
chunks = [encode(chnk) for chnk in txt.split(' ')]

model.reset_states()
with torch.no_grad():  # inference only: don't build an autograd graph across chunks
    for chnk in chunks:
        model(enc2seq(chnk))  # call the module (not .forward) so hooks run

assert model.states.shape == (N_RNN, 1, H_DIM)

In [None]:
# BATCHED INPUT: the same word sequence replicated across a batch.
BATCH_SIZE = 8

txt = 'The quick brown fox jumps over the lazy dog.'
chunks = [encode(chnk) for chnk in txt.split(' ')]

model.reset_states()
with torch.no_grad():  # inference only: don't build an autograd graph across chunks
    for chnk in chunks:
        seq = enc2seq(chnk)                            # (len(chnk), len_abc)
        batch = seq.expand(BATCH_SIZE, *seq.shape)     # (BATCH_SIZE, len(chnk), len_abc), no copy
        model(batch)  # call the module (not .forward) so hooks run

assert model.states.shape[0] == N_RNN

In [None]:
GEN_LEN = 128    # number of characters to sample
BLOCK_SIZE = 8   # sampling steps between transformer mixing steps

model.reset_states()
generated = model.generate(GEN_LEN, BLOCK_SIZE)
print(decode(generated))