In [None]:
import string

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Vocabulary: a space followed by ASCII punctuation, digits and letters.
abc = ''.join((' ', string.punctuation, string.digits, string.ascii_letters))
len_abc = len(abc)
print(f'{abc=}\n{len_abc=}')

# Lookup tables between vocabulary indices and characters (inverse of each other).
itoc = dict(enumerate(abc))
ctoi = {char: index for index, char in itoc.items()}

def encode(s: str) -> list[int]:
    """Map every character of ``s`` to its index in the ``ctoi`` table.

    A character that is not a key of ``ctoi`` raises ``KeyError``.
    """
    return list(map(ctoi.__getitem__, s))

def decode(l: list[int]) -> str:
    """Turn a sequence of vocabulary indices back into text via ``itoc``.

    An index that is not a key of ``itoc`` raises ``KeyError``.
    """
    return ''.join(itoc[index] for index in l)

def enc2tnsr(l: list[int]) -> torch.Tensor:
    """Build a tensor from ``l`` and cast it to 64-bit integers."""
    indices = torch.tensor(l)
    return indices.to(torch.long)

def enc2seq(l: list[int]) -> torch.Tensor:
    """One-hot encode the indices in ``l`` over ``len_abc`` classes, as floats.

    The result has the shape of ``enc2tnsr(l)`` with one extra trailing
    dimension of size ``len_abc``.
    """
    indices = enc2tnsr(l)
    one_hot = F.one_hot(indices, num_classes=len_abc)
    return one_hot.float()

def enct2seq(t: torch.Tensor) -> torch.Tensor:
    """One-hot encode an index tensor over ``len_abc`` classes, as floats.

    ``t`` is handed to ``F.one_hot`` unchanged, so it must already be an
    integer index tensor.
    """
    one_hot = F.one_hot(t, num_classes=len_abc)
    return one_hot.float()

def str2seq(s: str) -> torch.Tensor:
    """Encode ``s`` character by character and one-hot the result.

    Returns a float tensor with one row per character and ``len_abc``
    columns.
    """
    # Same pipeline as encode -> enc2seq, expressed through those helpers.
    return enc2seq(encode(s))

In [None]:
class AttentionHead(nn.Module):
    """Single unmasked self-attention head with scaled dot-product scores.

    Queries and keys are projected to ``key_dim`` features; values keep
    ``in_dim`` features, so the output has the same shape as the input.

    Args:
        in_dim: Size of the last dimension of the input.
        key_dim: Size of the query/key projections.
    """

    def __init__(self, in_dim, key_dim) -> None:
        super().__init__()

        # 1 / sqrt(key_dim), applied to the raw dot-product scores.
        self.attn_scale = key_dim**-0.5

        self.q = nn.Linear(in_dim, key_dim, bias=False)
        self.k = nn.Linear(in_dim, key_dim, bias=False)
        self.v = nn.Linear(in_dim, in_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention over the second-to-last axis of ``x``.

        Args:
            x: Tensor of shape ``(..., seq_len, in_dim)``. Any number of
                leading batch dimensions (including none) is accepted.

        Returns:
            Tensor with the same shape as ``x``.
        """
        Q, K, V = self.q(x), self.k(x), self.v(x)
        # Negative axes keep this valid for unbatched and multi-batch input,
        # matching the convention used by WriteHead / ReadHead.
        A = (Q @ K.transpose(-2, -1)) * self.attn_scale
        # Normalise over the key axis so each query's weights sum to 1.
        A = F.softmax(A, dim=-1)
        return A @ V


class MultiHeadAttention(nn.Module):
    """Several ``AttentionHead`` modules run in parallel and merged linearly.

    Each head receives ``key_dim // n_heads`` query/key features. Head
    outputs (each ``in_dim`` wide) are concatenated on the feature axis and
    projected back to ``in_dim``.

    Args:
        n_heads: Number of attention heads (at least 1).
        key_dim: Total query/key width shared between the heads; must be at
            least ``n_heads`` so every head gets a non-zero width.
        in_dim: Size of the last dimension of the input and output.

    Raises:
        ValueError: If ``n_heads < 1`` or ``key_dim < n_heads``.
    """

    def __init__(self, n_heads, key_dim, in_dim) -> None:
        # A zero-width head would otherwise fail later with an opaque
        # ZeroDivisionError when computing its 0 ** -0.5 scale.
        if n_heads < 1 or key_dim < n_heads:
            raise ValueError(
                f'need n_heads >= 1 and key_dim >= n_heads, got {n_heads=}, {key_dim=}'
            )

        super().__init__()

        self.heads = nn.ModuleList([AttentionHead(in_dim, key_dim//n_heads) for _ in range(n_heads)])
        self.proj = nn.Linear(n_heads*in_dim, in_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run every head on ``x`` and project the concatenated outputs.

        Args:
            x: Tensor whose last dimension is ``in_dim``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Concatenate on the feature axis, whatever the number of batch dims.
        t = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.proj(t)


class TransformerBlock(nn.Module):
    """Residual multi-head self-attention followed by a residual ReLU MLP.

    Args:
        in_dim: Feature size of the input and output.
        key_dim: Total query/key width passed to the attention module.
        n_heads: Number of attention heads.
        h_dim: Hidden width of the two-layer MLP.
    """

    def __init__(self, in_dim, key_dim, n_heads, h_dim) -> None:
        super().__init__()

        self.attention = MultiHeadAttention(n_heads=n_heads, key_dim=key_dim, in_dim=in_dim)
        feed_forward = [nn.Linear(in_dim, h_dim), nn.ReLU(), nn.Linear(h_dim, in_dim)]
        self.mlp = nn.Sequential(*feed_forward)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` after the attention and MLP residual updates."""
        attended = x + self.attention(x)
        return attended + self.mlp(attended)

In [None]:
class WriteHead(nn.Module):
    """Attention-style head that computes an updated memory from an input.

    Memory rows supply the keys; the input supplies queries and values. The
    softmax runs over the memory-slot axis (``dim=-2``), and a sigmoid gate
    blends the written values with the previous memory content.

    NOTE(review): the element-wise products in ``forward`` rely on
    broadcasting between the token axis of ``x`` and the slot axis of ``M``;
    this appears to assume ``x`` carries a single token -- confirm before
    feeding longer sequences.

    Args:
        in_dim: Feature size of the input ``x``.
        n_mem: Number of memory slots (rows of ``M``).
        mem_dim: Feature size of each memory slot.
        key_dim: Width of the query/key projections.
    """

    def __init__(self, in_dim, n_mem, mem_dim, key_dim) -> None:
        super().__init__()

        self.scaler = key_dim**-0.5

        self.q = nn.Linear(in_dim, key_dim, bias=False)
        self.k = nn.Linear(mem_dim, key_dim, bias=False)
        self.v = nn.Linear(in_dim, mem_dim, bias=False)
        # Gate weights: map values (mem_dim) to one logit per memory slot.
        self.g = nn.Parameter(torch.rand(mem_dim, n_mem))

    def forward(self, x: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
        """Return the memory obtained by writing ``x`` into ``M``.

        Args:
            x: Input whose last dimension is ``in_dim``.
            M: Memory with trailing dimensions ``(n_mem, mem_dim)``.
        """
        queries = self.q(x)
        keys = self.k(M)
        values = self.v(x)

        # Slot-by-token scores, normalised across the memory slots.
        scores = keys @ queries.transpose(-1, -2) * self.scaler
        A = F.softmax(scores, dim=-2)
        G = F.sigmoid(values @ self.g @ M)

        written = A @ values * G
        retained = (1 - A) * M * (1 - G)
        return written + retained


class ReadHead(nn.Module):
    """Attention-style head that reads from a memory given an input query.

    The input supplies the queries; memory rows supply keys and values. The
    softmax runs over the memory slots (``dim=-1``). A sigmoid gate mixes the
    attended values with the input itself, and the result is summed over the
    token axis.

    Args:
        in_dim: Feature size of the input ``x`` (and of the output).
        n_mem: Number of memory slots (rows of ``M``).
        mem_dim: Feature size of each memory slot.
        key_dim: Width of the query/key projections.
    """

    def __init__(self, in_dim, n_mem, mem_dim, key_dim) -> None:
        super().__init__()

        self.scaler = key_dim**-0.5

        self.q = nn.Linear(in_dim, key_dim, bias=False)
        self.k = nn.Linear(mem_dim, key_dim, bias=False)
        self.v = nn.Linear(mem_dim, in_dim, bias=False)
        # Gate weights: map the input (in_dim) to one logit per memory slot.
        self.g = nn.Parameter(torch.rand(in_dim, n_mem))

    def forward(self, x: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
        """Return the gated read-out of ``M`` for ``x``.

        Args:
            x: Input whose last dimension is ``in_dim``.
            M: Memory with trailing dimensions ``(n_mem, mem_dim)``.

        Returns:
            Tensor whose token axis (``dim=-2``) has been summed away.
        """
        queries, keys, values = self.q(x), self.k(M), self.v(M)

        # Token-by-slot scores, normalised across the memory slots.
        scores = queries @ keys.transpose(-1, -2) * self.scaler
        A = F.softmax(scores, dim=-1)
        G = F.sigmoid(x @ self.g @ values)

        mixed = A @ values * G + (1 - G) * x
        return mixed.sum(dim=-2)

In [None]:
class MultiHeadWrite(nn.Module):
    """Several ``WriteHead`` modules whose outputs are merged linearly.

    Each head gets ``key_dim // n_heads`` query/key features. The per-head
    memories are concatenated on the last axis and projected back to
    ``mem_dim``.
    """

    def __init__(self, in_dim, n_mem, mem_dim, key_dim, n_heads) -> None:
        super().__init__()

        write_heads = []
        for _ in range(n_heads):
            write_heads.append(WriteHead(in_dim, n_mem, mem_dim, key_dim // n_heads))
        self.heads = nn.ModuleList(write_heads)
        self.proj = nn.Linear(n_heads * mem_dim, mem_dim)

    def forward(self, x: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
        """Return the new memory produced by all heads writing ``x`` into ``M``."""
        stacked = torch.cat([write_head(x, M) for write_head in self.heads], dim=-1)
        return self.proj(stacked)


class MultiHeadRead(nn.Module):
    """Several ``ReadHead`` modules whose outputs are merged linearly.

    Each head gets ``key_dim // n_heads`` query/key features. The per-head
    read-outs are concatenated on the last axis and projected back to
    ``in_dim``.
    """

    def __init__(self, in_dim, n_mem, mem_dim, key_dim, n_heads) -> None:
        super().__init__()

        read_heads = []
        for _ in range(n_heads):
            read_heads.append(ReadHead(in_dim, n_mem, mem_dim, key_dim // n_heads))
        self.heads = nn.ModuleList(read_heads)
        self.proj = nn.Linear(n_heads * in_dim, in_dim)

    def forward(self, x: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
        """Return the combined read-out of ``M`` for ``x`` across all heads."""
        stacked = torch.cat([read_head(x, M) for read_head in self.heads], dim=-1)
        return self.proj(stacked)

In [None]:
class Memory(nn.Module):
    def __init__(self, in_dim, n_mem, mem_dim, key_dim, n_heads) -> None:
        super(Memory, self).__init__()

        self.register_buffer('M', torch.rand(1, n_mem, mem_dim))
        self.write_block = MultiHeadWrite(in_dim, n_mem, mem_dim, key_dim, n_heads)
        self.read_block = MultiHeadRead(in_dim, n_mem, mem_dim, key_dim, n_heads)

    def reset_memory(self) -> None:
        self.M = torch.rand(self.M.shape)
        nn.init.xavier_normal_(self.M)

    def forward(self, x: torch.Tensor, op: str = None) -> torch.Tensor | None:
        if x.ndim == 2:
            x = x.unsqueeze(-2)
        
        if x.shape[0] > 1 and self.M.shape[0] != x.shape[0]:
            self.M = self.M.expand(x.shape[0], *self.M.shape[1:])

        if op == 'w' or op == None:
            self.M = self.write_block(x, self.M)

        if op == 'r' or op == None:
            return self.read_block(x, self.M)
        return None

In [None]:
# Seed the global RNG before any module is built so that parameter
# initialisation and later sampling are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Memory hyper-parameters.
IN_DIM = len_abc
N_MEM = 32
MEM_DIM = 32
KEY_DIM = 16
N_HEADS = 4

# Transformer hyper-parameters (the transformer operates on the memory slots).
T_N_BLOCKS = 4
T_KEY_DIM = 16
T_N_HEADS = 4
T_H_DIM = 64

memory = Memory(IN_DIM, N_MEM, MEM_DIM, KEY_DIM, N_HEADS)

t_args = [MEM_DIM, T_KEY_DIM, T_N_HEADS, T_H_DIM]
t_modules = [TransformerBlock(*t_args) for _ in range(T_N_BLOCKS)]
transformer = nn.Sequential(*t_modules)

# Maps a read-out (IN_DIM features) to a probability per vocabulary entry.
classifier = nn.Sequential(
    nn.Linear(IN_DIM, IN_DIM),
    nn.Softmax(dim=1)
)

# Parameter counts of the memory module and of the transformer stack.
print(sum(p.numel() for p in memory.parameters()))
print(sum(p.numel() for p in transformer.parameters()))

In [None]:
# Demo sentence and its encoding as a list of vocabulary indices.
txt = 'The quick brown fox jumps over the lazy dog.'
encoded = encode(txt)

In [None]:
# UNBATCHED INPUT
# Write the encoded sentence into a freshly reset memory one character at a
# time, passing the memory through the transformer after every write.

memory.reset_memory()
for code in encoded:
    # One one-hot row of shape (1, len_abc); Memory.forward inserts the token axis.
    seq = enc2seq([code])
    # op='w' only writes, so the call returns None and just updates memory.M.
    memory(seq, op='w')
    memory.M = transformer(memory.M)

In [None]:
# BATCHED INPUT
# Same write loop as the unbatched case, with every character replicated
# across a batch.

BATCH_SIZE = 32

memory.reset_memory()
for code in encoded:
    # enc2seq is declared for a list of indices; the result is (1, len_abc).
    seq = enc2seq([code])
    # Replicate as a view (no copy) to shape (BATCH_SIZE, 1, len_abc).
    seq = seq.expand(BATCH_SIZE, *seq.shape)
    memory(seq, op='w')
    memory.M = transformer(memory.M)

In [None]:
# SEQUENCE GENERATION TEST
# Sample tokens one at a time: each sampled index is fed back as the next
# input, and the transformer is applied to the memory every `block_size` steps.

n_tokens = 100
block_size = 10
result = []
token = 0  # first input: vocabulary index 0 (the space character)

with torch.no_grad():
    memory.reset_memory()

    for step in range(n_tokens):
        # Refresh the memory after every full block of generated tokens.
        if step and step % block_size == 0:
            memory.M = transformer(memory.M)

        # Default op (None): write the token, then read from the memory.
        read_out = memory(enc2seq([token]))
        probs = classifier(read_out)[0]
        token = torch.multinomial(probs, num_samples=1).item()
        result.append(token)

print(decode(result))