In [52]:
import string

import torch
import torch.nn as nn
import torch.nn.functional as F

In [53]:
# Vocabulary: one space, then ASCII punctuation, digits and letters, in that order.
abc = ''.join((' ', string.punctuation, string.digits, string.ascii_letters))
len_abc = len(abc)
print(len_abc, abc)

# Lookup tables: position in `abc` -> character, and character -> position.
itoc = dict(enumerate(abc))
ctoi = {char: index for index, char in itoc.items()}

def encode(s: str) -> list[int]:
    """Translate every character of `s` into its index in the `ctoi` table.

    Raises:
        KeyError: if `s` contains a character that is not a key of `ctoi`.
    """
    indices = []
    for char in s:
        indices.append(ctoi[char])
    return indices

def decode(l: list[int]) -> str:
    """Turn a sequence of indices back into text via the `itoc` table.

    Raises:
        KeyError: if `l` contains an index that is not a key of `itoc`.
    """
    chars = (itoc[index] for index in l)
    return ''.join(chars)

def enc2tnsr(l: list[int]) -> torch.Tensor:
    """Build a tensor from `l` and cast it to 64-bit integers."""
    as_tensor = torch.tensor(l)
    return as_tensor.to(torch.int64)

def enc2seq(l: list[int]) -> torch.Tensor:
    """One-hot encode the indices in `l` as floats, with `len_abc` classes on the last axis."""
    # Convert to an integer tensor first, then reuse the tensor-based encoder.
    return enct2seq(enc2tnsr(l))

def enct2seq(t: torch.Tensor) -> torch.Tensor:
    """One-hot encode an index tensor as float32, with `len_abc` classes on a new last axis."""
    one_hot = F.one_hot(t, num_classes=len_abc)
    return one_hot.to(torch.float32)

def str2seq(s: str) -> torch.Tensor:
    """One-hot encode the characters of `s` as floats, one row of `len_abc` classes per character.

    Raises:
        KeyError: propagated from `encode` for characters outside the alphabet.
    """
    # Same pipeline as before (encode -> long tensor -> one-hot -> float),
    # expressed through the existing helpers instead of repeating them.
    return enc2seq(encode(s))

95  !"#$%&'()*+,-./:;<=>?@[\]^_`{|}~0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ


In [54]:
class AttentionHead(nn.Module):
    """Single self-attention head using scaled dot-product scores.

    Queries and keys are projected to `key_idm` features; values keep `in_dim`
    features, so the output has the same trailing size as the input.

    Args:
        in_dim: size of the last axis of the input.
        key_idm: size of the query/key projections. The spelling is kept as-is
            so that existing keyword callers keep working.
    """

    def __init__(self, in_dim, key_idm) -> None:
        super().__init__()

        # Scores are multiplied by 1/sqrt(key size).
        self.attn_scale = key_idm**-0.5

        self.q = nn.Linear(in_dim, key_idm, bias=False)
        self.k = nn.Linear(in_dim, key_idm, bias=False)
        self.v = nn.Linear(in_dim, in_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend every position of `x` to every other position.

        Args:
            x: tensor whose last two axes are (sequence, in_dim). Any number of
                leading axes (including none) is accepted.

        Returns:
            Tensor with the same shape as `x`.
        """
        Q, K, V = self.q(x), self.k(x), self.v(x)
        # Negative axes address the (sequence, feature) axes regardless of how
        # many leading batch axes there are; for 3-D input this equals the
        # previous hard-coded transpose(1, 2) / softmax(dim=2).
        A = (Q @ K.transpose(-2, -1)) * self.attn_scale
        A = F.softmax(A, dim=-1)
        return A @ V


class MultiHeadAttention(nn.Module):
    """Runs `n_heads` attention heads on the same input and mixes them linearly.

    Args:
        n_heads: number of `AttentionHead` modules.
        key_dim: total query/key size; each head receives `key_dim // n_heads`.
        in_dim: size of the last axis of the input (and of the output).
    """

    def __init__(self, n_heads, key_dim, in_dim) -> None:
        super().__init__()

        self.heads = nn.ModuleList([AttentionHead(in_dim, key_dim//n_heads) for _ in range(n_heads)])
        # Each head returns `in_dim` features, so the concatenation is
        # n_heads*in_dim wide before being projected back to in_dim.
        self.proj = nn.Linear(n_heads*in_dim, in_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Concatenate all head outputs on the feature axis and project to `in_dim`."""
        # The feature axis is always the last one; addressing it as -1 instead
        # of the hard-coded 2 gives the same result for 3-D input without
        # tying this module to a particular rank.
        t = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.proj(t)


class TransformerBlock(nn.Module):
    """Residual multi-head self-attention followed by a residual two-layer ReLU MLP.

    Args:
        in_dim: size of the last axis of the input and output.
        key_dim: total query/key size handed to `MultiHeadAttention`.
        n_heads: number of attention heads.
        h_dim: hidden width of the MLP.
    """

    def __init__(self, in_dim, key_dim, n_heads, h_dim) -> None:
        super().__init__()

        self.attention = MultiHeadAttention(n_heads, key_dim, in_dim)
        feed_forward = [nn.Linear(in_dim, h_dim), nn.ReLU(), nn.Linear(h_dim, in_dim)]
        self.mlp = nn.Sequential(*feed_forward)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply attention and the MLP, each added back onto its own input."""
        attended = x + self.attention(x)
        return attended + self.mlp(attended)

In [55]:
class WriteHead(nn.Module):
    """Gated write of an input into a memory matrix.

    Queries and values are projected from the input `x`, keys from the memory
    `M`. `g` is a learnable (mem_dim, n_mem) matrix, drawn uniformly from
    [0, 1), that feeds the write gate.
    """

    def __init__(self, in_dim, n_mem, mem_dim, key_dim) -> None:
        super().__init__()

        self.q = nn.Linear(in_dim, key_dim)
        self.k = nn.Linear(mem_dim, key_dim)
        self.v = nn.Linear(in_dim, mem_dim)
        self.g = nn.Parameter(torch.rand(mem_dim, n_mem))

    def forward(self, x: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
        """Return the updated memory.

        NOTE(review): the element-wise products below rely on broadcasting
        between terms indexed by memory rows and terms indexed by input rows;
        the notebook only calls this with a single input row per batch item.
        """
        queries = self.q(x)
        keys = self.k(M)
        values = self.v(x)

        # Key/query scores, normalised over axis -2 (the axis indexed by the rows of M).
        scores = keys @ queries.transpose(-1, -2)
        attn = F.softmax(scores, dim=-2)

        # Sigmoid gate computed from the projected input, `g` and the current memory.
        gate = F.sigmoid(values @ self.g @ M)

        written = (attn @ values) * gate
        retained = (1 - attn) * M * (1 - gate)
        return written + retained


class ReadHead(nn.Module):
    """Gated read from a memory matrix.

    Queries are projected from the input `x`; keys and values from the memory
    `M`. `g` is a learnable (in_dim, n_mem) matrix, drawn uniformly from
    [0, 1), that feeds the read gate.
    """

    def __init__(self, in_dim, n_mem, mem_dim, key_dim) -> None:
        super().__init__()

        self.q = nn.Linear(in_dim, key_dim)
        self.k = nn.Linear(mem_dim, key_dim)
        self.v = nn.Linear(mem_dim, in_dim)
        self.g = nn.Parameter(torch.rand(in_dim, n_mem))

    def forward(self, x: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
        """Return the gated mix of the attended memory values and `x`, summed over axis -2."""
        queries = self.q(x)
        keys = self.k(M)
        values = self.v(M)

        # Query/key scores, normalised over the last axis (the axis indexed by the rows of M).
        attn = F.softmax(queries @ keys.transpose(-1, -2), dim=-1)

        # Sigmoid gate computed from the raw input, `g` and the projected memory.
        gate = F.sigmoid(x @ self.g @ values)

        mixed = (attn @ values) * gate + (1 - gate) * x
        return mixed.sum(dim=-2)

In [56]:
class MultiHeadWrite(nn.Module):
    """Several `WriteHead`s applied to the same input and memory, merged by a linear layer.

    Each head receives `key_dim // n_heads` as its key size. The head outputs
    are concatenated on the last axis (n_heads*mem_dim wide) and projected
    back to `mem_dim`.
    """

    def __init__(self, in_dim, n_mem, mem_dim, key_dim, n_heads) -> None:
        super().__init__()

        per_head_key_dim = key_dim // n_heads
        self.heads = nn.ModuleList(
            [WriteHead(in_dim, n_mem, mem_dim, per_head_key_dim) for _ in range(n_heads)]
        )
        self.proj = nn.Linear(n_heads * mem_dim, mem_dim)

    def forward(self, x: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
        """Return the projected concatenation of every head's written memory."""
        per_head = []
        for write in self.heads:
            per_head.append(write(x, M))
        stacked = torch.cat(per_head, dim=-1)
        return self.proj(stacked)


class MultiHeadRead(nn.Module):
    """Several `ReadHead`s applied to the same input and memory, merged by a linear layer.

    Each head receives `key_dim // n_heads` as its key size. The head outputs
    are concatenated on the last axis (n_heads*in_dim wide) and projected
    back to `in_dim`.
    """

    def __init__(self, in_dim, n_mem, mem_dim, key_dim, n_heads) -> None:
        super().__init__()

        per_head_key_dim = key_dim // n_heads
        self.heads = nn.ModuleList(
            [ReadHead(in_dim, n_mem, mem_dim, per_head_key_dim) for _ in range(n_heads)]
        )
        self.proj = nn.Linear(n_heads * in_dim, in_dim)

    def forward(self, x: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
        """Return the projected concatenation of every head's read-out."""
        per_head = []
        for read in self.heads:
            per_head.append(read(x, M))
        stacked = torch.cat(per_head, dim=-1)
        return self.proj(stacked)


class Memory(nn.Module):
    """Stateful memory matrix that is written to and then read from on every call.

    The memory is held in the buffer `M`, created with shape
    (1, n_mem, mem_dim). `forward` replaces it with the output of the write
    heads, so the state persists between calls until `reset_memory` is used.

    Args:
        in_dim: feature size of the inputs passed to `forward`.
        n_mem: number of memory rows.
        mem_dim: feature size of each memory row.
        key_dim: total key size shared out among the heads.
        n_heads: number of write heads and of read heads.
    """

    def __init__(self, in_dim, n_mem, mem_dim, key_dim, n_heads) -> None:
        super().__init__()

        # NOTE(review): the initial memory is uniform in [0, 1) while
        # `reset_memory` uses Xavier-uniform; confirm this asymmetry is intended.
        self.register_buffer('M', torch.rand(1, n_mem, mem_dim))
        self.write_heads = MultiHeadWrite(in_dim, n_mem, mem_dim, key_dim, n_heads)
        self.read_heads = MultiHeadRead(in_dim, n_mem, mem_dim, key_dim, n_heads)

    def reset_memory(self, batch_size=None) -> None:
        """Replace the memory with a fresh Xavier-uniform tensor.

        Args:
            batch_size: size of the leading axis of the new memory. ``None``
                (the default) keeps whatever leading size the memory currently
                has, which is the original behaviour; pass ``1`` to return to
                the un-batched state after a batched run.
        """
        shape = tuple(self.M.shape)
        if batch_size is not None:
            shape = (batch_size, *shape[1:])
        # `new_empty` keeps the buffer's dtype and device (a bare torch.rand
        # would silently fall back to the default dtype on the CPU) and yields
        # a tensor without autograd history.
        fresh = self.M.new_empty(shape)
        self.M = nn.init.xavier_uniform_(fresh)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Write `x` into the memory, then read from the updated memory.

        Args:
            x: input of shape (batch, in_dim) or (batch, rows, in_dim). A 2-D
                input gets a singleton axis inserted before the feature axis.

        Returns:
            The output of the read heads for `x` and the updated memory.

        Raises:
            RuntimeError: if `x` has batch size > 1 and the memory already has
                a different batch size other than 1.
        """
        if x.ndim == 2:
            x = x.unsqueeze(-2)

        M = self.M
        if x.shape[0] > 1 and M.shape[0] != x.shape[0]:
            if M.shape[0] != 1:
                raise RuntimeError(
                    f"memory has batch size {M.shape[0]} but the input has batch size "
                    f"{x.shape[0]}; call reset_memory(batch_size=1) before switching batch size"
                )
            # Share the single memory across the batch; kept in a local so the
            # stride-0 view is never stored as the module's state.
            M = M.expand(x.shape[0], *M.shape[1:])

        self.M = self.write_heads(x, M)
        return self.read_heads(x, self.M)

In [57]:
# Hyper-parameters for the memory module and the transformer stack.
IN_DIM = len_abc   # one-hot width: one class per character of the alphabet
N_MEM = 32         # number of memory rows
MEM_DIM = 16       # feature size of each memory row
KEY_DIM = 16       # total key size, split across the heads
N_HEADS = 4
N_T_BLOCKS = 4     # number of stacked transformer blocks
T_H_DIM = 64       # hidden width of each transformer MLP
SEED = 0

# Seed the global RNG before any module is built so that the initial weights
# and the initial memory are the same on every Restart & Run All.
torch.manual_seed(SEED)

mem = Memory(IN_DIM, N_MEM, MEM_DIM, KEY_DIM, N_HEADS)

# The transformer is applied to the memory matrix itself, so its input size is
# MEM_DIM rather than IN_DIM.
transformer_args = [MEM_DIM, KEY_DIM, N_HEADS, T_H_DIM]
transformer_modules = [TransformerBlock(*transformer_args) for _ in range(N_T_BLOCKS)]
transformer = nn.Sequential(*transformer_modules)

# Parameter counts: memory module first, transformer stack second.
print(sum(p.numel() for p in mem.parameters()))
print(sum(p.numel() for p in transformer.parameters()))

67663
18816


In [58]:
# Sample sentence used by the driver cells below, and its per-character
# indices in the alphabet (one int per character, via `encode`).
txt = 'The quick brown fox jumps over the lazy dog.'
encoded = encode(txt)

In [59]:
# UNBATCHED INPUT
# Feed the sample text into the memory one character at a time, then run the
# transformer stack over the resulting memory.

# Start from a freshly initialised memory.
# NOTE(review): reset_memory keeps the memory's current shape, so if the
# batched cell was run before this one the memory is still batched here.
mem.reset_memory()

for code in encoded:
    # One-hot row of shape (1, len_abc) for a single character.
    seq = enc2seq([code])
    # Each call writes into mem.M and returns the read-out; only the read-out
    # of the last character is still bound to m_r after the loop.
    m_r = mem(seq)

# Apply the transformer stack to the memory ten times, storing the result
# back into the memory buffer each time.
# NOTE(review): nothing here disables autograd, so mem.M presumably keeps the
# graph of every step — confirm that is wanted for this demo.
for _ in range(10):
    mem.M = transformer(mem.M)

In [60]:
# BATCHED INPUT
# Same procedure as the unbatched run, but every character is replicated
# across a batch so the memory is expanded to one copy per batch element.

BATCH_SIZE = 16

h = None  # NOTE(review): not read anywhere in the visible cells; kept in case a later cell uses it
mem.reset_memory()

for code in encoded:
    # enc2seq is annotated to take a list of ints, so wrap the single code
    # (as the unbatched cell does). That gives a (1, len_abc) one-hot row,
    # which is then broadcast to (BATCH_SIZE, len_abc) without copying.
    seq = enc2seq([code]).expand(BATCH_SIZE, -1)
    m_r = mem(seq)

# Apply the transformer stack to the batched memory ten times.
for _ in range(10):
    mem.M = transformer(mem.M)