In [148]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import string

In [149]:
# Character vocabulary: space, then punctuation, digits and ASCII letters.
# A character's position in `abc` is its integer code.
abc = ''.join((' ', string.punctuation, string.digits, string.ascii_letters))
len_abc = len(abc)
print(len_abc, abc)

# Lookup tables in both directions: index -> character and character -> index.
itoc = dict(enumerate(abc))
ctoi = {char: idx for idx, char in itoc.items()}

def encode(s: str) -> list[int]:
    """Map every character of `s` to its integer code via `ctoi`.

    Raises:
        KeyError: if `s` contains a character that is not a key of `ctoi`.
    """
    return list(map(ctoi.__getitem__, s))

def decode(l: list[int]) -> str:
    """Turn a list of integer codes back into a string via `itoc`.

    Raises:
        KeyError: if a code is not a key of `itoc`.
    """
    chars = (itoc[code] for code in l)
    return ''.join(chars)

def enc2tnsr(l: list[int]) -> torch.Tensor:
    """Convert a list of integer codes into a 1-D int64 tensor."""
    codes = torch.tensor(l, dtype=torch.long)
    return codes

def enc2seq(l: list[int]) -> torch.Tensor:
    """One-hot encode a list of integer codes.

    Returns:
        Float tensor with one row per code and `len_abc` columns.
    """
    indices = enc2tnsr(l)
    one_hot = F.one_hot(indices, num_classes=len_abc)
    return one_hot.float()

def enct2seq(t: torch.Tensor) -> torch.Tensor:
    """One-hot encode an integer tensor of codes.

    Returns:
        Float tensor shaped like `t` with a trailing axis of size `len_abc`.
    """
    one_hot = F.one_hot(t, num_classes=len_abc)
    return one_hot.float()

def str2seq(s: str) -> torch.Tensor:
    """One-hot encode a string, one row per character.

    Composes the two helpers: `encode` maps characters to codes and
    `enc2seq` expands the codes into float one-hot rows of width `len_abc`.
    """
    return enc2seq(encode(s))

95  !"#$%&'()*+,-./:;<=>?@[\]^_`{|}~0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ


In [150]:
class WriteHead(nn.Module):
    """Single attention head that produces an updated memory from an input.

    The query and the value are projected from the input `x`; the key is
    projected from the memory `M`. Attention weights are normalised across
    dim -2 of the score matrix, and a sigmoid gate computed from `[M, V]`
    blends the written value with the old memory.

    Args:
        in_dim: feature size of the input `x`.
        n_mem: size that V is expanded to along dim -2 before the concat gate.
        mem_dim: feature size of the memory `M`.
        key_dim: feature size of the query/key projections.
    """

    def __init__(self, in_dim, n_mem, mem_dim, key_dim) -> None:
        super().__init__()

        # Scaled dot-product factor, 1 / sqrt(key_dim).
        self.scale = key_dim**(-0.5)
        # Target shape used to broadcast V across the memory axis for the gate.
        self.shape = (-1, n_mem, -1)

        # Layers are created in the order q, k, v, g.
        self.q = nn.Linear(in_dim, key_dim)
        self.k = nn.Linear(mem_dim, key_dim)
        self.v = nn.Linear(in_dim, mem_dim)

        # Gating variant [2]: concat-then-linear.
        # Variant [1] (general-attention gating) is kept for reference:
        # self.g = nn.Parameter(torch.rand(mem_dim, n_mem))
        self.g = nn.Linear(2*mem_dim, mem_dim)

    def forward(self, x: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
        """Return the gated mix of the written value and the previous memory.

        NOTE(review): `V.expand(self.shape)` only succeeds when dim -2 of `x`
        is 1 or n_mem; callers in this notebook pass a single step — confirm
        before feeding longer sequences.
        """
        query = self.q(x)
        key = self.k(M)
        value = self.v(x)

        # Softmax over dim -2 of the (key x query) score matrix.
        scores = key @ query.transpose(-1, -2) * self.scale
        attn = F.softmax(scores, dim=-2)

        # ----- [1]
        # gate = F.sigmoid(value @ self.g @ M)

        # ----- [2]
        gate_input = torch.cat((M, value.expand(self.shape)), dim=-1)
        gate = F.sigmoid(self.g(gate_input))

        written = attn @ value * gate
        retained = (1 - attn) * M * (1 - gate)
        return written + retained


class ReadHead(nn.Module):
    """Single attention head that reads from the memory given an input.

    The query is projected from the input `x`; key and value are projected
    from the memory `M`. Attention weights are normalised across dim -1 of
    the score matrix, and a sigmoid gate computed from `[V, x]` blends the
    attended value with the raw input before summing over dim -2.

    Args:
        in_dim: feature size of the input `x` (and of the value projection).
        n_mem: size that `x` is expanded to along dim -2 before the concat gate.
        mem_dim: feature size of the memory `M`.
        key_dim: feature size of the query/key projections.
    """

    def __init__(self, in_dim, n_mem, mem_dim, key_dim) -> None:
        super().__init__()

        # Scaled dot-product factor, 1 / sqrt(key_dim).
        self.scale = key_dim**(-0.5)
        # Target shape used to broadcast x across the memory axis for the gate.
        self.shape = (-1, n_mem, -1)

        # Layers are created in the order q, k, v, g.
        self.q = nn.Linear(in_dim, key_dim)
        self.k = nn.Linear(mem_dim, key_dim)
        self.v = nn.Linear(mem_dim, in_dim)

        # Gating variant [2]: concat-then-linear.
        # Variant [1] (general-attention gating) is kept for reference:
        # self.g = nn.Parameter(torch.rand(in_dim, n_mem))
        self.g = nn.Linear(2*in_dim, in_dim)

    def forward(self, x: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
        """Return the gated read-out, summed over dim -2.

        NOTE(review): `x.expand(self.shape)` only succeeds when dim -2 of `x`
        is 1 or n_mem; callers in this notebook pass a single step — confirm
        before feeding longer sequences.
        """
        query = self.q(x)
        key = self.k(M)
        value = self.v(M)

        # Softmax over dim -1 of the (query x key) score matrix.
        scores = query @ key.transpose(-1, -2) * self.scale
        attn = F.softmax(scores, dim=-1)

        # ----- [1]
        # gate = F.sigmoid(x @ self.g @ value)

        # ----- [2]
        gate_input = torch.cat((value, x.expand(self.shape)), dim=-1)
        gate = F.sigmoid(self.g(gate_input))

        blended = attn @ value * gate + (1 - gate) * x
        return torch.sum(blended, dim=-2)

In [151]:
class WriteBlock(nn.Module):
    """Multi-head write: runs several `WriteHead`s and projects their outputs.

    Each head receives `key_dim // n_heads` as its key size. The heads'
    outputs are concatenated on the last axis and mapped back to `mem_dim`
    by a linear projection.
    """

    def __init__(self, in_dim, n_mem, mem_dim, key_dim, n_heads) -> None:
        super().__init__()

        head_key_dim = key_dim // n_heads
        # Heads are built before the projection.
        self.heads = nn.ModuleList()
        for _ in range(n_heads):
            self.heads.append(WriteHead(in_dim, n_mem, mem_dim, head_key_dim))
        self.proj = nn.Linear(n_heads*mem_dim, mem_dim)

    def forward(self, x: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
        """Return the projected concatenation of every head's written memory."""
        per_head = []
        for head in self.heads:
            per_head.append(head(x, M))
        stacked = torch.cat(per_head, dim=-1)
        return self.proj(stacked)


class ReadBlock(nn.Module):
    """Multi-head read: runs several `ReadHead`s and projects their outputs.

    Each head receives `key_dim // n_heads` as its key size. The heads'
    outputs are concatenated on the last axis and mapped back to `in_dim`
    by a linear projection.
    """

    def __init__(self, in_dim, n_mem, mem_dim, key_dim, n_heads) -> None:
        super().__init__()

        head_key_dim = key_dim // n_heads
        # Heads are built before the projection.
        self.heads = nn.ModuleList()
        for _ in range(n_heads):
            self.heads.append(ReadHead(in_dim, n_mem, mem_dim, head_key_dim))
        self.proj = nn.Linear(n_heads*in_dim, in_dim)

    def forward(self, x: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
        """Return the projected concatenation of every head's read-out."""
        per_head = []
        for head in self.heads:
            per_head.append(head(x, M))
        stacked = torch.cat(per_head, dim=-1)
        return self.proj(stacked)

In [152]:
class Memory(nn.Module):
    """Stateful memory: every forward call writes the input into `M`, then reads from it.

    `M` is registered as a buffer of shape `(1, n_mem, mem_dim)`. Each forward
    call replaces it with the output of the write block, so the module carries
    state from one call to the next until `reset_memory()` is called.

    Args:
        in_dim: feature size of the input.
        n_mem: number of memory rows (dim -2 of `M`).
        mem_dim: feature size of each memory row.
        key_dim: total key size, split across heads by the blocks.
        n_heads: number of heads in the write and read blocks.
    """

    def __init__(self, in_dim, n_mem, mem_dim, key_dim, n_heads) -> None:
        super().__init__()

        self.register_buffer('M', torch.rand(1, n_mem, mem_dim))
        self.write_block = WriteBlock(in_dim, n_mem, mem_dim, key_dim, n_heads)
        self.read_block = ReadBlock(in_dim, n_mem, mem_dim, key_dim, n_heads)

    def reset_memory(self) -> None:
        """Replace `M` with a fresh Xavier-uniform memory of batch size 1.

        The batch dimension is always reset to 1: a previous batched forward
        leaves `M` with that batch size, and keeping it would make the next
        forward with a different batch size fail. The new tensor is created
        with the dtype and device of the current buffer so the module keeps
        working after `.to(...)`.
        """
        fresh = torch.empty(
            1, *self.M.shape[1:], dtype=self.M.dtype, device=self.M.device
        )
        nn.init.xavier_uniform_(fresh)
        self.M = fresh

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Write `x` into the memory, then return what the read block reads back.

        A 2-D `x` gets a singleton axis inserted at dim -2. The stored memory
        is the raw output of the write block (it is not detached), so with
        autograd enabled its history grows across calls until `reset_memory()`.

        Raises:
            RuntimeError: if `x` has batch size > 1 while the memory already
                holds a different batch size > 1.
        """
        if x.ndim == 2:
            x = x.unsqueeze(-2)

        batch = x.shape[0]
        if batch > 1 and self.M.shape[0] != batch:
            if self.M.shape[0] != 1:
                # expand() can only grow a size-1 axis; fail with a clear message.
                raise RuntimeError(
                    f'memory holds batch size {self.M.shape[0]} but input has '
                    f'batch size {batch}; call reset_memory() before switching '
                    'batch sizes'
                )
            self.M = self.M.expand(batch, *self.M.shape[1:])

        self.M = self.write_block(x, self.M)
        return self.read_block(x, self.M)

In [153]:
# Hyperparameters of the demo memory module.
IN_DIM = len_abc  # inputs are one-hot rows over the vocabulary
N_MEM = 5
MEM_DIM = 16
KEY_DIM = 16      # split across heads inside the blocks (KEY_DIM // N_HEADS each)
N_HEADS = 4
SEED = 0

# Seed the global RNG so parameter and memory initialisation are reproducible
# on Restart & Run All.
torch.manual_seed(SEED)

mem = Memory(IN_DIM, N_MEM, MEM_DIM, KEY_DIM, N_HEADS)
# Total number of learnable parameters (the memory buffer is not a parameter).
print(sum(p.numel() for p in mem.parameters()))

128147


In [154]:
# Demo sentence; every character it contains is part of `abc`.
txt = "The quick brown fox jumps over the lazy dog."
encoded = encode(txt)

In [155]:
# UNBATCHED INPUT
# Feed the sentence one character at a time as a (1, len_abc) one-hot row.
# The one-hot row is used directly; it was previously overwritten by a
# leftover `torch.rand(1, len_abc)` debug input.

mem.reset_memory()
for code in encoded:
    t = enc2seq([code])
    m_r = mem(t)#; print(f'{m_r.shape=}')

In [156]:
# BATCHED INPUT
# Same walk over the sentence, with each one-hot row broadcast to a batch of 8:
# (1, len_abc) -> (1, 1, len_abc) -> (8, 1, len_abc).

mem.reset_memory()
for code in encoded:
    t = enc2seq([code]).unsqueeze(0).expand(8, -1, -1)
    m_r = mem(t)#; print(f'{m_r.shape=}')