In [104]:
import string

import torch
import torch.nn as nn
import torch.nn.functional as F

In [105]:
# Character vocabulary: a space followed by punctuation, digits and ASCII letters.
abc = ''.join((' ', string.punctuation, string.digits, string.ascii_letters))
len_abc = len(abc)
print(len_abc, abc)

# Lookup tables: index -> character and its inverse, character -> index.
itoc = dict(enumerate(abc))
ctoi = {ch: idx for idx, ch in itoc.items()}

def encode(s: str) -> list[int]:
    """Translate every character of ``s`` into its vocabulary index.

    The lookup goes through the module-level ``ctoi`` table, so a character
    that is not part of the vocabulary raises ``KeyError``.
    """
    return list(map(ctoi.__getitem__, s))

def decode(l: list[int]) -> str:
    """Rebuild a string from vocabulary indices using the ``itoc`` table.

    An index that is not a key of ``itoc`` raises ``KeyError``.
    """
    return ''.join(itoc[idx] for idx in l)

def enc2tnsr(l: list[int]) -> torch.Tensor:
    """Wrap a list of indices in a tensor and cast it to 64-bit integers."""
    raw = torch.tensor(l)
    return raw.to(torch.long)

def enc2seq(l: list[int]) -> torch.Tensor:
    """One-hot encode a list of vocabulary indices.

    The indices are converted with ``enc2tnsr`` and expanded to ``len_abc``
    classes along a new trailing axis; the result is returned as floats.
    """
    indices = enc2tnsr(l)
    one_hot = F.one_hot(indices, num_classes=len_abc)
    return one_hot.float()

def enct2seq(t: torch.Tensor) -> torch.Tensor:
    """One-hot encode an index tensor into ``len_abc`` classes, returned as float32."""
    one_hot = F.one_hot(t, num_classes=len_abc)
    return one_hot.to(torch.float32)

def str2seq(s: str) -> torch.Tensor:
    """One-hot encode a string in a single step.

    Equivalent to encoding the text with ``encode`` and passing the indices
    to ``enc2seq``, which performs the same tensor conversion and one-hot
    expansion this function used to spell out inline.
    """
    return enc2seq(encode(s))

95  !"#$%&'()*+,-./:;<=>?@[\]^_`{|}~0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ


In [106]:
class WriteHead(nn.Module):
    """Single attention head that computes an updated memory from an input.

    Queries and values are linear projections of the input ``x``; keys are a
    linear projection of the memory ``M``.  A sigmoid gate built from the
    values, the learned matrix ``g`` and ``M`` blends the attended values with
    the previous memory content.
    """

    def __init__(self, in_dim, n_mem, mem_dim, key_dim) -> None:
        super().__init__()

        # Scaled dot-product factor: 1 / sqrt(key_dim).
        self.scale = key_dim**(-0.5)

        # Submodules are created in the order q, k, v, g so that seeded
        # initialisation draws random numbers in the same sequence.
        self.q = nn.Linear(in_dim, key_dim)
        self.k = nn.Linear(mem_dim, key_dim)
        self.v = nn.Linear(in_dim, mem_dim)
        self.g = nn.Parameter(torch.rand(mem_dim, n_mem))

    def forward(self, x: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
        """Return the gated mix of attended values and the old memory.

        ``M`` must have ``n_mem`` rows in its second-to-last axis for the gate
        product with ``g`` to be defined.  The elementwise products below rely
        on broadcasting between the input axis and the memory axis.
        NOTE(review): that appears to require a single input position per
        call (second-to-last axis of ``x`` of size 1) - confirm with callers.
        """
        queries = self.q(x)
        keys = self.k(M)
        values = self.v(x)

        # Key/query similarity, normalised along the second-to-last axis
        # (the axis that comes from the rows of M).
        scores = keys @ queries.transpose(-1, -2) * self.scale
        attn = F.softmax(scores, dim=-2)

        gate = torch.sigmoid(values @ self.g @ M)

        written = attn @ values * gate
        retained = (1 - attn) * M * (1 - gate)
        return written + retained


class ReadHead(nn.Module):
    """Single attention head that reads from the memory for a given input.

    Queries are a linear projection of the input ``x``; keys and values are
    linear projections of the memory ``M``.  A sigmoid gate decides, per
    feature, how much of the attended memory replaces the input itself.
    """

    def __init__(self, in_dim, n_mem, mem_dim, key_dim) -> None:
        super().__init__()

        # Scaled dot-product factor: 1 / sqrt(key_dim).
        self.scale = key_dim**(-0.5)

        # Submodules are created in the order q, k, v, g so that seeded
        # initialisation draws random numbers in the same sequence.
        self.q = nn.Linear(in_dim, key_dim)
        self.k = nn.Linear(mem_dim, key_dim)
        self.v = nn.Linear(mem_dim, in_dim)
        self.g = nn.Parameter(torch.rand(in_dim, n_mem))

    def forward(self, x: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
        """Return the gated read-out, summed over the second-to-last axis.

        ``M`` must have ``n_mem`` rows in its second-to-last axis for the gate
        product with ``g`` to be defined.  Because of the final sum, the
        position axis of ``x`` is collapsed in the result.
        """
        queries = self.q(x)
        keys = self.k(M)
        values = self.v(M)

        # Query/key similarity, normalised along the last axis
        # (the axis that comes from the rows of M).
        scores = queries @ keys.transpose(-1, -2) * self.scale
        attn = F.softmax(scores, dim=-1)

        gate = torch.sigmoid(x @ self.g @ values)

        from_memory = attn @ values * gate
        from_input = (1 - gate) * x
        return torch.sum(from_memory + from_input, dim=-2)

In [107]:
class WriteBlock(nn.Module):
    """Multi-head memory writer.

    Runs ``n_heads`` independent ``WriteHead`` modules, each with a key size
    of ``key_dim // n_heads``, concatenates their outputs along the last axis
    and projects the result back to ``mem_dim`` features.
    """

    def __init__(self, in_dim, n_mem, mem_dim, key_dim, n_heads) -> None:
        super().__init__()

        head_key_dim = key_dim // n_heads
        self.heads = nn.ModuleList(
            WriteHead(in_dim, n_mem, mem_dim, head_key_dim) for _ in range(n_heads)
        )
        self.proj = nn.Linear(n_heads*mem_dim, mem_dim)

    def forward(self, x: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
        """Return the projected concatenation of every head's written memory."""
        per_head = tuple(head(x, M) for head in self.heads)
        stacked = torch.cat(per_head, dim=-1)
        return self.proj(stacked)


class ReadBlock(nn.Module):
    """Multi-head memory reader.

    Runs ``n_heads`` independent ``ReadHead`` modules, each with a key size
    of ``key_dim // n_heads``, concatenates their outputs along the last axis
    and projects the result back to ``in_dim`` features.
    """

    def __init__(self, in_dim, n_mem, mem_dim, key_dim, n_heads) -> None:
        super().__init__()

        head_key_dim = key_dim // n_heads
        self.heads = nn.ModuleList(
            ReadHead(in_dim, n_mem, mem_dim, head_key_dim) for _ in range(n_heads)
        )
        self.proj = nn.Linear(n_heads*in_dim, in_dim)

    def forward(self, x: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
        """Return the projected concatenation of every head's read-out."""
        per_head = tuple(head(x, M) for head in self.heads)
        stacked = torch.cat(per_head, dim=-1)
        return self.proj(stacked)

In [108]:
class Memory(nn.Module):
    def __init__(self, in_dim, n_mem, mem_dim, key_dim, n_heads) -> None:
        super(Memory, self).__init__()

        self.register_buffer('M', torch.rand(1, n_mem, mem_dim))
        self.write_block = WriteBlock(in_dim, n_mem, mem_dim, key_dim, n_heads)
        self.read_block = ReadBlock(in_dim, n_mem, mem_dim, key_dim, n_heads)

    def reset_memory(self) -> None:
        self.M = torch.rand(1, *self.M.shape[1:])
        nn.init.xavier_normal_(self.M)

    def forward(self, x: torch.Tensor, write_only: bool = False) -> torch.Tensor | None:
        if x.ndim == 2:
            x = x.unsqueeze(-2)        
        
        if x.shape[0] > 1 and self.M.shape[0] != x.shape[0]:
            self.M = self.M.expand(x.shape[0], *self.M.shape[1:])
        
        self.M = self.write_block(x, self.M)

        if not write_only:
            return self.read_block(x, self.M)
        return None

In [109]:
class TransformerBlock(nn.Module):
    """Self-attention followed by a two-layer MLP.

    The attention output is added to the input (residual connection) and the
    sum is fed through the MLP; there is no residual around the MLP itself.
    The attention layer is built with ``batch_first=False``.
    """

    def __init__(self, emb_dim, n_heads, h_dim) -> None:
        super().__init__()

        self.mha = nn.MultiheadAttention(emb_dim, n_heads, batch_first=False)
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, emb_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention over ``x``, add the residual and run the MLP."""
        # The attention weights (second return value) are not needed here.
        attended = self.mha(x, x, x)[0]
        residual = x + attended
        return self.mlp(residual)

In [None]:
# --- Model hyper-parameters -------------------------------------------------
IN_DIM = len_abc      # one-hot input width (vocabulary size)
H_DIM = 32            # hidden/embedding width shared by all components

# Memory settings
N_MEM = 16
MEM_DIM = 16
KEY_DIM = 16
N_HEADS = 4

# Transformer settings
T_N_BLOCKS = 4
T_H_DIM = 32

# --- Model components -------------------------------------------------------
projection = nn.Linear(IN_DIM, H_DIM)
t_args = (H_DIM, N_HEADS, T_H_DIM)
t_modules = [TransformerBlock(*t_args) for _ in range(T_N_BLOCKS)]
transformer = nn.Sequential(*t_modules)
memory = Memory(H_DIM, N_MEM, MEM_DIM, KEY_DIM, N_HEADS)
# Softmax must normalise over the vocabulary (last) axis.  With dim=0 a
# (1, len_abc) logits tensor is normalised over its size-1 axis, which yields
# all ones and makes sampling uniform regardless of the logits.
classifier = nn.Sequential(nn.Linear(H_DIM, len_abc), nn.Softmax(dim=-1))

print('MEMORY SIZE:', sum([p.numel() for p in memory.parameters()]))
print('TRANSFORMER SIZE:', sum([p.numel() for p in transformer.parameters()]))

MEMORY SIZE: 14128
TRANSFORMER SIZE: 25344


In [111]:
# Sample text and its vocabulary-index encoding (see encode); both names are
# module-level so that later cells can consume them.
txt = 'The quick brown fox jumps over the lazy dog.'
encoded = encode(txt)

In [None]:
n_tokens = 128   # number of characters to sample
chnk_sz = 8      # context window / chunk length
result = []
ctxs = None

memory.reset_memory()

# Pure inference: with autograd disabled, neither the persistent memory state
# nor the in-place context updates below accumulate an ever-growing graph.
with torch.no_grad():
    # PREPROCESS SEQUENCE WITH TRANSFORMER AT CHUNK-LEVEL
    # Start at index 0 so the first character of the prompt is not dropped.
    for i in range(0, len(encoded), chnk_sz):
        chnk = projection(enc2seq(encoded[i:i+chnk_sz]))
        ctxs = transformer(chnk).unsqueeze(-2)
        for x in ctxs:
            memory(x, write_only=True)

    # An empty prompt produces no chunk: start from an empty context so the
    # padding step below fills the whole window with zeros.
    if ctxs is None:
        ctxs = torch.zeros(0, 1, H_DIM)

    # MAKE CONTEXTS SAME LENGTH AS CHUNK_SIZE
    if len(ctxs) < chnk_sz:
        filler = torch.zeros(chnk_sz-len(ctxs), 1, H_DIM)
        ctxs = torch.cat((filler, ctxs))

    j = 0
    # GENERATE SEQUENCE STARTING FROM CURRENT CHUNK CONTEXTS
    for _ in range(n_tokens):
        # Once a full window of tokens has been generated, pass every context
        # slot through the memory (write + read) and restart the slot counter.
        if j == chnk_sz:
            for k in range(chnk_sz):
                ctxs[k] = memory(ctxs[k])
            j = 0

        # Sample the next character from the distribution at the last slot.
        p = classifier(ctxs[-1])
        idx = torch.multinomial(p, 1, replacement=True).item()
        result.append(idx)

        # 1) add generated token to context window in circular fashion
        # 2) update contexts via transformer
        ctxs[j] = projection(enc2seq([idx]))
        ctxs = transformer(ctxs)
        j += 1

print(decode(result))

`^oHUFPIVzm|V)|8wOC-?tQt su6%kT2;R'ThIC~rQ2Y$$]o&,{6;]F=.Kl;/6jiCSJqv-i=C1lXs4Sa)&wU3~\^"'[BB88t leO2$P`9zM=2$zhx_pjP^hEH`9d cx8
