In [2]:
import string

import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
# Character vocabulary: a space followed by punctuation, digits and ASCII letters.
abc = ''.join((' ', string.punctuation, string.digits, string.ascii_letters))
len_abc = len(abc)
print(len_abc, abc)

# Lookup tables: index -> character and character -> index.
itoc = dict(enumerate(abc))
ctoi = {char: index for index, char in itoc.items()}

def encode(s: str) -> list[int]:
    """Map every character of ``s`` to its index in the ``ctoi`` table.

    A character that is not a key of ``ctoi`` raises ``KeyError``.
    """
    return list(ctoi[char] for char in s)

def decode(l: list[int]) -> str:
    """Turn a list of vocabulary indices back into a string via ``itoc``.

    An index that is not a key of ``itoc`` raises ``KeyError``.
    """
    pieces = []
    for index in l:
        pieces.append(itoc[index])
    return ''.join(pieces)

def enc2tnsr(l: list[int]) -> torch.Tensor:
    """Build a tensor from a list of indices and cast it to ``torch.long``."""
    as_tensor = torch.tensor(l)
    return as_tensor.to(torch.long)

def enc2seq(l: list[int]) -> torch.Tensor:
    """One-hot encode a list of indices as a float tensor with ``len_abc`` classes."""
    indices = enc2tnsr(l)
    one_hot = F.one_hot(indices, num_classes=len_abc)
    return one_hot.float()

def enct2seq(t: torch.Tensor) -> torch.Tensor:
    """One-hot encode an index tensor as a float tensor with ``len_abc`` classes."""
    one_hot = F.one_hot(t, num_classes=len_abc)
    return one_hot.float()

def str2seq(s: str) -> torch.Tensor:
    """Encode ``s`` character by character and one-hot it as a float tensor.

    Same pipeline as ``enc2seq`` applied to ``encode(s)``: indices -> long
    tensor -> one-hot with ``len_abc`` classes -> float.
    """
    return enc2seq(encode(s))

95  !"#$%&'()*+,-./:;<=>?@[\]^_`{|}~0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ


In [4]:
class WriteHead(nn.Module):
    """Single head that computes an updated memory from an input and the current memory.

    The query and value are projected from the input ``x``; the key is
    projected from the memory ``M``. ``Wg`` must be multiplied with ``M``, so
    ``M`` needs ``n_mem`` rows of size ``mem_dim``.
    """

    def __init__(self, in_dim, n_mem, mem_dim, key_dim) -> None:
        super().__init__()

        self.Wq = nn.Linear(in_dim, key_dim)   # input  -> query
        self.Wk = nn.Linear(mem_dim, key_dim)  # memory -> key
        self.Wv = nn.Linear(in_dim, mem_dim)   # input  -> value written to memory
        # Gate weights, initialised uniformly in [0, 1).
        self.Wg = nn.Parameter(torch.rand(mem_dim, n_mem))

    def forward(self, x: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
        """Return the memory after writing ``x`` into ``M``.

        Assumes ``x`` is ``(..., 1, in_dim)`` and ``M`` is
        ``(..., n_mem, mem_dim)`` so that the element-wise products broadcast
        -- TODO confirm against callers.
        """
        query = self.Wq(x)
        key = self.Wk(M)
        value = self.Wv(x)

        # Attention of the input over memory slots, normalised along the slot axis.
        scores = torch.matmul(key, query.transpose(-1, -2))
        attn = torch.softmax(scores, dim=-2)

        # Sigmoid gate computed from the value and the current memory.
        gate = torch.sigmoid(torch.matmul(torch.matmul(value, self.Wg), M))

        written = torch.matmul(attn, value) * gate
        retained = (1 - attn) * M * (1 - gate)
        return written + retained


class ReadHead(nn.Module):
    """Single head that reads from the memory ``M`` using the input ``x`` as query.

    The query is projected from ``x``; key and value are projected from ``M``.
    The value projection maps memory rows back to ``in_dim`` so the read-out
    can be mixed with ``x``.
    """

    def __init__(self, in_dim, n_mem, mem_dim, key_dim) -> None:
        super().__init__()

        self.Wq = nn.Linear(in_dim, key_dim)   # input  -> query
        self.Wk = nn.Linear(mem_dim, key_dim)  # memory -> key
        self.Wv = nn.Linear(mem_dim, in_dim)   # memory -> value in input space
        # Gate weights, initialised uniformly in [0, 1).
        self.Wg = nn.Parameter(torch.rand(in_dim, n_mem))

    def forward(self, x: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
        """Return the gated mix of the memory read-out and ``x``, summed over dim ``-2``.

        Assumes ``x`` is ``(..., n_query, in_dim)`` and ``M`` is
        ``(..., n_mem, mem_dim)`` -- TODO confirm against callers.
        """
        query = self.Wq(x)
        key = self.Wk(M)
        value = self.Wv(M)

        # Attention of each query over memory slots, normalised along the slot axis.
        scores = torch.matmul(query, key.transpose(-1, -2))
        attn = torch.softmax(scores, dim=-1)

        # Sigmoid gate computed from the input and the projected memory.
        gate = torch.sigmoid(torch.matmul(torch.matmul(x, self.Wg), value))

        mixed = torch.matmul(attn, value) * gate + (1 - gate) * x
        return mixed.sum(dim=-2)

In [5]:
class WriteHeadBlock(nn.Module):
    """Several ``WriteHead`` modules whose outputs are concatenated and projected.

    Each head gets ``key_dim // n_heads`` key dimensions; the concatenated
    head outputs (``n_heads * mem_dim`` features) are mapped back to
    ``mem_dim`` by ``W_o``.
    """

    def __init__(self, in_dim, n_mem, mem_dim, key_dim, n_heads) -> None:
        super().__init__()

        head_key_dim = key_dim // n_heads
        heads = []
        for _ in range(n_heads):
            heads.append(WriteHead(in_dim, n_mem, mem_dim, head_key_dim))
        self.heads = nn.ModuleList(heads)
        self.W_o = nn.Linear(n_heads * mem_dim, mem_dim)

    def forward(self, x: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
        """Run every head on ``(x, M)``, concatenate on the last axis and project."""
        per_head = tuple(head(x, M) for head in self.heads)
        stacked = torch.cat(per_head, dim=-1)
        return self.W_o(stacked)


class ReadHeadBlock(nn.Module):
    """Several ``ReadHead`` modules whose outputs are concatenated and projected.

    Each head gets ``key_dim // n_heads`` key dimensions; the concatenated
    head outputs (``n_heads * in_dim`` features) are mapped back to ``in_dim``
    by ``W_o``.
    """

    def __init__(self, in_dim, n_mem, mem_dim, key_dim, n_heads) -> None:
        super().__init__()

        head_key_dim = key_dim // n_heads
        heads = []
        for _ in range(n_heads):
            heads.append(ReadHead(in_dim, n_mem, mem_dim, head_key_dim))
        self.heads = nn.ModuleList(heads)
        self.W_o = nn.Linear(n_heads * in_dim, in_dim)

    def forward(self, x: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
        """Run every head on ``(x, M)``, concatenate on the last axis and project."""
        per_head = tuple(head(x, M) for head in self.heads)
        stacked = torch.cat(per_head, dim=-1)
        return self.W_o(stacked)

In [6]:
class Memory(nn.Module):
    """Memory bank with a multi-head write block and a multi-head read block.

    The memory ``M`` is a registered buffer created with shape
    ``(1, n_mem, mem_dim)``. Every forward call first replaces ``M`` with the
    write block's output for the input and then returns the read block's
    output for the same input and the updated memory.
    """

    def __init__(self, in_dim, n_mem, mem_dim, key_dim, n_heads) -> None:
        super().__init__()

        self.register_buffer('M', torch.rand(1, n_mem, mem_dim))
        self.write_block = WriteHeadBlock(in_dim, n_mem, mem_dim, key_dim, n_heads)
        self.read_block = ReadHeadBlock(in_dim, n_mem, mem_dim, key_dim, n_heads)

    def reset_memory(self) -> None:
        """Replace ``M`` with a fresh Xavier-uniform memory of batch size 1.

        The leading dimension is reset to 1 so that a batch size left over
        from an earlier batched forward pass does not leak into later calls;
        ``forward`` expands the memory again when it receives a batch. The
        new tensor keeps the dtype and device of the current buffer and is
        not attached to any autograd graph.
        """
        fresh = torch.empty(
            1, *self.M.shape[1:], dtype=self.M.dtype, device=self.M.device
        )
        nn.init.xavier_uniform_(fresh)
        self.M = fresh

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Write ``x`` into the memory, then read from the updated memory.

        Args:
            x: input whose last dimension is ``in_dim``. A 2-D input gets a
                singleton axis inserted before the last one.

        Returns:
            The read block's output for ``x`` and the updated memory.

        Raises:
            RuntimeError: if the stored memory already has a batch size that
                is neither 1 nor the batch size of ``x``.
        """
        if x.ndim == 2:
            x = x.unsqueeze(-2)

        batch = x.shape[0]
        if batch > 1 and self.M.shape[0] != batch:
            if self.M.shape[0] != 1:
                raise RuntimeError(
                    f'memory holds batch size {self.M.shape[0]} but input has '
                    f'batch size {batch}; call reset_memory() first'
                )
            # Share the single memory across the batch (a view, no copy).
            self.M = self.M.expand(batch, *self.M.shape[1:])

        self.M = self.write_block(x, self.M)
        return self.read_block(x, self.M)

In [13]:
# Model hyper-parameters.
IN_DIM = len_abc
H_DIM = 32
N_MEM = 32
MEM_DIM = 16
KEY_DIM = 16
N_HEADS = 4

rnn = nn.RNN(IN_DIM, H_DIM, batch_first=True)
mem = Memory(H_DIM, N_MEM, MEM_DIM, KEY_DIM, N_HEADS)
# Normalise over the class (last) axis. With dim=0 a (1, len_abc) logit row is
# normalised over its single leading entry, which turns every probability
# into 1.0 and makes sampling ignore the logits.
classifier = nn.Sequential(nn.Linear(H_DIM, len_abc), nn.Softmax(dim=-1))

# Parameter counts of the recurrent core and of the memory module.
print(sum(p.numel() for p in rnn.parameters()))
print(sum(p.numel() for p in mem.parameters()))

4128
17200


In [244]:
# Sample sentence, split on single spaces and encoded word by word.
txt = 'The quick brown fox jumps over the lazy dog.'
chunks = list(map(encode, txt.split(' ')))

In [245]:
# UNBATCHED INPUT

mem.reset_memory()
h = None
for chnk in chunks:
    # Run the RNN over one word, carrying the hidden state across words.
    seq = enc2seq(chnk)
    _, h = rnn(seq, h)
    # Write the hidden state into memory, read it back and add the read-out to h.
    m_r = mem(h)
    h = h + m_r

In [246]:
# BATCHED INPUT

mem.reset_memory()
h = None
for chnk in chunks:
    # Repeat the same word 7 times along a new leading batch axis (a view, no copy).
    seq = enc2seq(chnk).unsqueeze(0).expand(7, -1, -1)
    _, h = rnn(seq, h)
    # The RNN returns h with the batch on axis 1; the memory takes it on axis 0.
    m_r = mem(h.transpose(0, 1))
    h = h + m_r

In [255]:
# Sample `n_tokens` characters from the model, passing the hidden state
# through the memory once every `block_size` generated tokens.
n_tokens = 128
block_size = 8

result = []
i, j = 0, 0  # i: index of the last sampled character, j: tokens since the last memory step

h = None
mem.reset_memory()
# Generation is unbatched: keep a single memory slice even if an earlier
# batched run left the memory expanded to a larger batch size.
mem.M = mem.M[:1]
with torch.no_grad():
    for t in range(n_tokens):
        if j == block_size:
            # NOTE(review): the training-style loops above add the memory
            # read-out to h; here it replaces h -- confirm which is intended.
            h = mem(h)
            j = 0

        _, h = rnn(enc2seq([i]), h)
        p = classifier(h)
        # Feed a plain int back in as the next input index.
        i = torch.multinomial(p, 1, True).item()
        result.append(i)
        # Count the generated token; without this the memory step never runs.
        j += 1

print(decode(result))

RwX$.2_m3xU.Rpzm8n;>:|M(#{1f\n!-KQ4bNQ\ShPx(xd<Q5l8AjK>.S2E*` \n7<J^P}8x6TM7<IO.-U21/o/0`mcXrovD3G,IPp,P;m_Ny Er&;X!9y[]t,D;1JvA
