In [None]:
import string

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Character vocabulary: space, ASCII punctuation, digits, then ASCII letters.
abc = ''.join((' ', string.punctuation, string.digits, string.ascii_letters))
len_abc = len(abc)

# Index <-> character lookup tables (every character in `abc` is distinct).
itoc = dict(enumerate(abc))
ctoi = {char: idx for idx, char in itoc.items()}

def encode(s: str) -> list[int]:
    """Map each character of ``s`` to its vocabulary index.

    Raises KeyError for any character that is not in the vocabulary ``abc``.
    """
    return list(map(ctoi.__getitem__, s))

def decode(l: list[int]) -> str:
    """Inverse of ``encode``: join the vocabulary character for each index in ``l``.

    Raises KeyError for an index with no entry in ``itoc``.
    """
    return ''.join(itoc[idx] for idx in l)

def enc2tnsr(l: list[int]) -> torch.Tensor:
    """Convert a list of vocabulary indices to an int64 tensor (empty list -> shape (0,))."""
    values = torch.tensor(l)
    return values.to(torch.int64)

def enc2seq(l: list[int]) -> torch.Tensor:
    """One-hot encode a list of vocabulary indices.

    Returns a float32 tensor of shape ``(len(l), len_abc)``.
    """
    indices = enc2tnsr(l)
    return F.one_hot(indices, num_classes=len_abc).to(torch.float32)

def enct2seq(t: torch.Tensor) -> torch.Tensor:
    """One-hot encode a tensor of vocabulary indices as float32.

    The result has the shape of ``t`` plus a trailing dimension of size ``len_abc``.

    ``F.one_hot`` only accepts int64 indices, so integer (and bool) tensors of any
    other dtype are cast to int64 first. Floating-point and complex tensors are passed
    through unchanged, and ``F.one_hot`` rejects them exactly as before.
    """
    if not (t.is_floating_point() or t.is_complex()):
        t = t.long()
    return F.one_hot(t, len_abc).float()

def str2seq(s: str) -> torch.Tensor:
    """One-hot encode string ``s`` as a float32 tensor of shape ``(len(s), len_abc)``.

    Delegates to ``encode`` and ``enc2seq`` so every string -> one-hot conversion
    follows the same path. Raises KeyError for characters outside the vocabulary.
    """
    return enc2seq(encode(s))

print(f'{len_abc} {abc}')  # vocabulary size, then the vocabulary itself

In [None]:
class WriteHead(nn.Module):
    """Single attention head that proposes an additive update for a memory matrix.

    The query is projected from the input ``x`` and the keys from the memory ``M``.
    The resulting scores are softmax-normalised across the memory slots. The value
    projected from ``x`` is spread over the slots by those weights. It is then scaled
    element-wise by a sigmoid gate computed from the value, the learned matrix ``Wg``
    and ``M``.
    """

    def __init__(self, in_dim, n_mem, mem_dim, key_dim) -> None:
        super().__init__()
        # Creation order is kept (Wq, Wk, Wv, Wg) so seeded initialisation is unchanged.
        self.Wq = nn.Linear(in_dim, key_dim)    # input -> query
        self.Wk = nn.Linear(mem_dim, key_dim)   # memory slot -> key
        self.Wv = nn.Linear(in_dim, mem_dim)    # input -> value to write
        self.Wg = nn.Parameter(torch.rand(mem_dim, n_mem))  # gate matrix, init U[0, 1)

    def forward(self, x: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
        """Return the proposed memory update.

        Shape comments assume ``x`` is (batch, in_dim) and ``M`` is
        (batch, n_mem, mem_dim) -- NOTE(review): confirm for other callers.
        """
        row = x.unsqueeze(-2)                       # (batch, 1, in_dim)
        query = self.Wq(row)                        # (batch, 1, key_dim)
        keys = self.Wk(M)                           # (batch, n_mem, key_dim)
        value = self.Wv(row)                        # (batch, 1, mem_dim)
        # One score per slot, normalised over the slot axis.
        scores = keys @ query.transpose(-1, -2)     # (batch, n_mem, 1)
        weights = F.softmax(scores, dim=-2)
        gate = torch.sigmoid(value @ self.Wg @ M)   # (batch, 1, mem_dim)
        return (weights @ value) * gate             # (batch, n_mem, mem_dim)


class ReadHead(nn.Module):
    """Single attention head that reads an ``in_dim``-sized vector out of a memory matrix.

    The query is projected from the input ``x``, and keys and values from the memory
    ``M``. The values are projected back into the input space. The attention-weighted
    read is scaled element-wise by a sigmoid gate computed from ``x``, the learned
    matrix ``Wg`` and the values.
    """

    def __init__(self, in_dim, n_mem, mem_dim, key_dim) -> None:
        super().__init__()
        # Creation order is kept (Wq, Wk, Wv, Wg) so seeded initialisation is unchanged.
        self.Wq = nn.Linear(in_dim, key_dim)    # input -> query
        self.Wk = nn.Linear(mem_dim, key_dim)   # memory slot -> key
        self.Wv = nn.Linear(mem_dim, in_dim)    # memory slot -> value in input space
        self.Wg = nn.Parameter(torch.rand(in_dim, n_mem))  # gate matrix, init U[0, 1)

    def forward(self, x: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
        """Return the gated read.

        Shape comments assume ``x`` is (batch, in_dim) and ``M`` is
        (batch, n_mem, mem_dim) -- NOTE(review): confirm for other callers.
        """
        row = x.unsqueeze(-2)                       # (batch, 1, in_dim)
        query = self.Wq(row)                        # (batch, 1, key_dim)
        keys = self.Wk(M)                           # (batch, n_mem, key_dim)
        values = self.Wv(M)                         # (batch, n_mem, in_dim)
        weights = F.softmax(query @ keys.transpose(-1, -2), dim=-1)  # (batch, 1, n_mem)
        gate = torch.sigmoid(row @ self.Wg @ values)                  # (batch, 1, in_dim)
        gated = (weights @ values) * gate                              # (batch, 1, in_dim)
        # The query axis has length 1, so summing over it simply removes it.
        return gated.sum(dim=-2)

In [None]:
class WriteHeadBlock(nn.Module):
    """Multi-head memory writer: runs ``n_heads`` WriteHeads and merges their updates.

    ``key_dim`` is divided among the heads with integer division, so each head uses a
    ``key_dim // n_heads``-wide key space. The per-head updates are concatenated on
    the last axis and projected back to ``mem_dim`` by ``W_o``.

    Raises:
        ValueError: if ``n_heads < 1`` or ``key_dim < n_heads``.
    """

    def __init__(self, in_dim, n_mem, mem_dim, key_dim, n_heads) -> None:
        super().__init__()
        # Validate before building any heads. With key_dim < n_heads every head would
        # get a zero-width query/key projection, making its attention silently uniform.
        if n_heads < 1:
            raise ValueError(f"n_heads must be >= 1, got {n_heads}")
        if key_dim < n_heads:
            raise ValueError(f"key_dim ({key_dim}) must be >= n_heads ({n_heads})")

        head_key_dim = key_dim // n_heads
        self.heads = nn.ModuleList(
            [WriteHead(in_dim, n_mem, mem_dim, head_key_dim) for _ in range(n_heads)]
        )
        # Concatenated per-head updates -> a single update of width mem_dim.
        self.W_o = nn.Linear(n_heads * mem_dim, mem_dim)

    def forward(self, x: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
        """Return the merged memory update (last axis of width ``mem_dim``)."""
        updates = [head(x, M) for head in self.heads]
        return self.W_o(torch.cat(updates, dim=-1))


class ReadHeadBlock(nn.Module):
    """Multi-head memory reader: runs ``n_heads`` ReadHeads and merges their reads.

    ``key_dim`` is divided among the heads with integer division, so each head uses a
    ``key_dim // n_heads``-wide key space. The per-head reads are concatenated on the
    last axis and projected back to ``in_dim`` by ``W_o``.

    Raises:
        ValueError: if ``n_heads < 1`` or ``key_dim < n_heads``.
    """

    def __init__(self, in_dim, n_mem, mem_dim, key_dim, n_heads) -> None:
        super().__init__()
        # Validate before building any heads. With key_dim < n_heads every head would
        # get a zero-width query/key projection, making its attention silently uniform.
        if n_heads < 1:
            raise ValueError(f"n_heads must be >= 1, got {n_heads}")
        if key_dim < n_heads:
            raise ValueError(f"key_dim ({key_dim}) must be >= n_heads ({n_heads})")

        head_key_dim = key_dim // n_heads
        self.heads = nn.ModuleList(
            [ReadHead(in_dim, n_mem, mem_dim, head_key_dim) for _ in range(n_heads)]
        )
        # Concatenated per-head reads -> a single read of width in_dim.
        self.W_o = nn.Linear(n_heads * in_dim, in_dim)

    def forward(self, x: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
        """Return the merged read (last axis of width ``in_dim``)."""
        reads = [head(x, M) for head in self.heads]
        return self.W_o(torch.cat(reads, dim=-1))

In [None]:
class Memory(nn.Module):
    """Differentiable memory matrix with multi-head write and read access.

    The memory lives in the buffer ``M`` of shape (batch, n_mem, mem_dim). It starts
    with a batch dimension of 1 and is broadcast to the input's batch size on first
    use. Each ``forward`` first writes an additive update into ``M``, then reads from
    the updated memory.
    """

    def __init__(self, in_dim, n_mem, mem_dim, key_dim, n_heads) -> None:
        super().__init__()

        # Memory contents are initialised uniformly in [-1, 1).
        self.register_buffer('M', (-1+2*torch.rand(1, n_mem, mem_dim)))
        self.write_block = WriteHeadBlock(in_dim, n_mem, mem_dim, key_dim, n_heads)
        self.read_block = ReadHeadBlock(in_dim, n_mem, mem_dim, key_dim, n_heads)

    def reset_memory(self, batch_size=None) -> None:
        """Re-draw the memory contents uniformly in [-1, 1).

        Args:
            batch_size: optional new batch dimension for ``M``. By default the current
                shape is kept, including any batch size ``M`` was expanded to.
        """
        shape = self.M.shape if batch_size is None else (batch_size, *self.M.shape[1:])
        # Keep the buffer on its current device and dtype (e.g. after model.to('cuda')).
        self.M = -1 + 2 * torch.rand(shape, device=self.M.device, dtype=self.M.dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Write ``x`` into memory, then read from the updated memory.

        Args:
            x: input whose first dimension is the batch size (presumably
                (batch, in_dim) -- confirm against callers).

        Returns:
            The read block's output for ``x`` against the updated memory.

        Raises:
            RuntimeError: if ``M`` has a batch size other than 1 that differs from
                ``x``'s batch size.
        """
        batch = x.shape[0]
        if self.M.shape[0] != batch:
            if self.M.shape[0] != 1:
                raise RuntimeError(
                    f"memory batch size {self.M.shape[0]} does not match input batch "
                    f"size {batch}; call reset_memory(batch_size={batch}) first"
                )
            # Broadcast the single initial memory to every batch element.
            self.M = self.M.expand(batch, *self.M.shape[1:])

        # Additive write: the write heads propose a delta for every memory slot.
        self.M = self.M + self.write_block(x, self.M)
        return self.read_block(x, self.M)

In [None]:
class TransformerBlock(nn.Module):
    """Multi-head self-attention followed by a two-layer ReLU MLP.

    The attention output is added residually to the input before the MLP. The MLP
    output itself is returned directly, with no residual and no normalisation.
    """

    def __init__(self, emb_dim, n_heads, h_dim) -> None:
        super().__init__()
        self.mha = nn.MultiheadAttention(emb_dim, n_heads, batch_first=False)
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, emb_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block.

        Because ``batch_first=False``, ``x`` is (seq, batch, emb_dim) or, unbatched,
        (seq, emb_dim).
        """
        mixed = x + self.mha(x, x, x)[0]   # [0]: attention output; weights are discarded
        return self.mlp(mixed)

In [None]:
class RNNNet(nn.Module):
    """A bank of parallel single-layer RNNs coupled by self-attention and external memory.

    On each ``forward`` every RNN runs over the input from its own recurrent state.
    Their final hidden states are mixed by a TransformerBlock (plus a residual). Each
    mixed state is then passed through that RNN's own Memory, and the memory read
    becomes the RNN's next recurrent state.

    Args:
        rnns_args: ``(n_rnn, in_dim, h_dim, n_heads, th_dim)``. These are the number
            of RNNs, the input feature size, the RNN hidden size, and the attention
            heads and MLP width of the mixing TransformerBlock.
        mem_args: ``(n_mem, mem_dim, mem_key_dim, n_mem_heads)``. These are the
            memory rows per Memory, the row width, the key width and the number of
            read/write heads.

    The recurrent state is kept in the ``states`` buffer and carried across calls to
    ``forward``. Its shape is (n_rnn, 1, h_dim) for unbatched input and
    (n_rnn, 1, batch, h_dim) for batched input.
    """

    def __init__(self, rnns_args: tuple, mem_args: tuple) -> None:
        super().__init__()

        n_rnn, in_dim, h_dim, n_heads, th_dim = rnns_args
        n_mem, mem_dim, mem_key_dim, n_mem_heads = mem_args

        self.n_rnn = n_rnn
        self.h_dim = h_dim

        self.register_buffer('states', torch.zeros(n_rnn, 1, h_dim))
        self.rnns = nn.ModuleList([nn.RNN(in_dim, h_dim, batch_first=True) for _ in range(n_rnn)])
        self.tblock = TransformerBlock(h_dim, n_heads, th_dim)
        # One Memory per RNN; forward pairs them up one-to-one. n_mem is the number
        # of rows inside each Memory, not the number of Memory modules.
        self.mems = nn.ModuleList([Memory(h_dim, n_mem, mem_dim, mem_key_dim, n_mem_heads) for _ in range(n_rnn)])

    def reset_states(self, batch_size=None) -> None:
        """Zero the recurrent states, keeping the buffer's device and dtype.

        Args:
            batch_size: if given, allocate batched states of shape
                (n_rnn, 1, batch_size, h_dim); otherwise (n_rnn, 1, h_dim).
        """
        if batch_size is None:
            shape = (self.n_rnn, 1, self.h_dim)
        else:
            shape = (self.n_rnn, 1, batch_size, self.h_dim)
        self.states = torch.zeros(shape, device=self.states.device, dtype=self.states.dtype)

    def reset_memory(self) -> None:
        """Re-draw the contents of every Memory module."""
        for mem in self.mems:
            mem.reset_memory()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Advance all RNNs over ``x`` and update ``self.states``.

        Args:
            x: unbatched (seq_len, in_dim) or batched (batch, seq_len, in_dim) input.
                The RNNs use ``batch_first=True``. For batched input, ``self.states``
                must already have shape (n_rnn, 1, batch, h_dim).

        Returns:
            The new ``self.states``.
        """
        batched = x.dim() == 3

        # Final hidden state of each RNN, stacked:
        # (n_rnn, h_dim) unbatched or (n_rnn, batch, h_dim) batched.
        h = torch.cat([rnn(x, s)[1] for rnn, s in zip(self.rnns, self.states)], dim=0)
        # Attention across the RNN axis, plus a residual.
        h = h + self.tblock(h)

        new_states = []
        for mem, h_i in zip(self.mems, h):
            if batched:
                # h_i is (batch, h_dim); restore the num_layers axis -> (1, batch, h_dim).
                new_states.append(mem(h_i).unsqueeze(0))
            else:
                # Memory expects a leading batch axis: (h_dim,) -> (1, h_dim).
                new_states.append(mem(h_i.unsqueeze(0)))

        # Build a fresh tensor instead of assigning into self.states[i]. Such in-place
        # writes would modify tensors autograd saved while computing the memory reads,
        # making backward() fail.
        self.states = torch.stack(new_states)
        return self.states

In [None]:
# Smoke-test configuration: a small random model driven by a constant input.
SEED = 0
torch.manual_seed(SEED)  # reproducible parameter init and initial memory contents

IN_DIM = 10
N_RNN = 3
H_DIM = 100
N_HEADS = 4
TH_DIM = 100

N_MEM = 3
MEM_DIM = 100
MEM_KEY_DIM = 10
N_MEM_HEADS = 2

# Passed explicitly: some PyTorch versions have no default SGD learning rate.
LR = 1e-3

rnn_args = (N_RNN, IN_DIM, H_DIM, N_HEADS, TH_DIM)
mem_args = (N_MEM, MEM_DIM, MEM_KEY_DIM, N_MEM_HEADS)

model = RNNNet(rnn_args, mem_args)
optim = torch.optim.SGD(model.parameters(), lr=LR)

# Unbatched (seq_len, IN_DIM) sequence, split into chunks of at most 8 time steps.
seq = torch.ones(10, IN_DIM)
chunks = torch.split(seq, 8)
print(seq.shape)
print(len(chunks), chunks[0].shape)

# The same sequence as a batch of 2, chunked along the time axis (dim=1).
batch = torch.stack([seq, seq])
batch_chunks = torch.split(batch, 8, dim=1)
print(batch.shape)
print(len(batch_chunks), batch_chunks[0].shape)

In [None]:
for chnk in chunks:
    # Every chunk starts from zero recurrent state (N_RNN, 1, H_DIM) and fresh memory.
    model.reset_states()
    model.reset_memory()
    optim.zero_grad()

    # Smoke-test objective: the sum of the updated recurrent states.
    model(chnk)
    model.states.sum().backward()
    optim.step()

In [None]:
for bchnk in batch_chunks:
    # Batched recurrent state is (N_RNN, num_layers=1, batch, H_DIM). The batch size
    # is taken from the chunk instead of being hard-coded.
    model.states = torch.zeros(N_RNN, 1, bchnk.shape[0], H_DIM)
    model.reset_memory()

    optim.zero_grad()
    model(bchnk)
    # Smoke-test objective: the sum of the updated recurrent states.
    model.states.sum().backward()
    optim.step()

In [61]:
# Character-level demo: feed a sentence to a freshly built model one word at a time.
SEED = 0
torch.manual_seed(SEED)  # reproducible parameter init and initial memory contents

IN_DIM = len_abc  # inputs are one-hot vocabulary characters
N_RNN = 4
H_DIM = 64
N_HEADS = 8
TH_DIM = 64

N_MEM = 32
MEM_DIM = 32
MEM_KEY_DIM = 16
N_MEM_HEADS = 4

rnns_args = (N_RNN, IN_DIM, H_DIM, N_HEADS, TH_DIM)
mem_args = (N_MEM, MEM_DIM, MEM_KEY_DIM, N_MEM_HEADS)
model = RNNNet(rnns_args, mem_args)

txt = 'The quick brown fox jumps over the lazy dog.'
# Splitting on ' ' means the space characters themselves are never fed to the model.
chunks = [encode(chnk) for chnk in txt.split(' ')]

model.reset_states()
model.reset_memory()
for chnk in chunks:
    seq = enc2seq(chnk)   # (word_len, len_abc) one-hot
    model(seq)            # call the module (runs hooks) rather than .forward directly