In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import string

In [21]:
# Alphabet of supported characters: a space followed by ASCII punctuation,
# digits and letters (lowercase first, then uppercase).
abc = ''.join((' ', string.punctuation, string.digits, string.ascii_letters))
len_abc = len(abc)
print(f'{abc=}\n{len_abc=}')

# Lookup tables between a character and its position in `abc`, both ways.
itoc = dict(enumerate(abc))
ctoi = {char: idx for idx, char in itoc.items()}

def encode(s: str) -> list[int]:
	"""Translate a string into a list of alphabet indices.

	Each character is looked up in the module-level ``ctoi`` table; a
	character missing from that table raises ``KeyError``.
	"""
	indices = []
	for ch in s:
		indices.append(ctoi[ch])
	return indices

def decode(l: list[int]) -> str:
	"""Translate a list of alphabet indices back into a string.

	Each index is looked up in the module-level ``itoc`` table; an index
	missing from that table raises ``KeyError``.
	"""
	chars = (itoc[idx] for idx in l)
	return ''.join(chars)

def enc2tnsr(l: list[int]) -> torch.Tensor:
	"""Build an int64 tensor from a list of indices."""
	raw = torch.tensor(l)
	# An empty list is inferred as float32, so the cast is always applied.
	return raw.to(torch.long)

def enc2seq(l: list[int]) -> torch.Tensor:
	"""One-hot encode a list of indices into a float tensor.

	The result has one row per index and ``len_abc`` columns.
	"""
	indices = enc2tnsr(l)
	one_hot = F.one_hot(indices, num_classes=len_abc)
	return one_hot.float()

def tnsr2seq(t: torch.Tensor) -> torch.Tensor:
	"""One-hot encode an integer index tensor into a float tensor.

	A trailing axis of size ``len_abc`` is appended to the shape of ``t``.
	"""
	one_hot = F.one_hot(t, num_classes=len_abc)
	return one_hot.float()

def str2seq(s: str) -> torch.Tensor:
	"""One-hot encode a string: one row per character, ``len_abc`` columns.

	This is ``encode`` followed by ``enc2seq``.
	"""
	return enc2seq(encode(s))

abc=' !"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
len_abc=95


In [22]:
class WriteHead(nn.Module):
	"""Attention head that computes a memory update from an input sequence.

	Rows of the memory ``M`` produce the queries and rows of the input ``x``
	produce the keys and values, so the output has one row per memory row.
	"""

	def __init__(self, in_dim, n_mem, mem_dim, key_dim) -> None:
		super().__init__()

		# Factor applied to the attention logits: key_dim ** -0.5.
		self.scale = key_dim**(-0.5)

		self.q = nn.Linear(mem_dim, key_dim, bias=False)  # memory -> queries
		self.k = nn.Linear(in_dim, key_dim, bias=False)   # input  -> keys
		self.v = nn.Linear(in_dim, mem_dim, bias=False)   # input  -> values
		# Mixing matrix for the gate: (values @ G) is multiplied with M.
		self.G = nn.Parameter(torch.rand(mem_dim, n_mem))
		# Bilinear form placed between queries and keys in the logits.
		self.U = nn.Parameter(torch.rand(key_dim, key_dim))

	def forward(self, x: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
		"""Return the update contributed by ``x`` for every memory row.

		Args:
			x: input of shape (..., seq_len, in_dim).
			M: memory of shape (..., n_mem, mem_dim).

		Returns:
			Tensor of shape (..., n_mem, mem_dim), with leading dimensions
			broadcast between ``x`` and ``M``.
		"""
		queries = self.q(M)
		keys = self.k(x)
		values = self.v(x)

		# Logits are (..., n_mem, seq_len). The softmax runs over the memory
		# axis, so every input position spreads a total weight of 1 across
		# the memory rows.
		logits = queries @ self.U @ keys.transpose(-1, -2) * self.scale
		attn = logits.softmax(dim=-2)

		# Element-wise gate in (0, 1) with the same shape as the values.
		gate = (values @ self.G @ M).sigmoid()
		return attn @ (gate * values)


class ReadHead(nn.Module):
	"""Attention head that reads from the memory using an input query.

	Rows of the input ``x`` produce the queries and rows of the memory ``M``
	produce the keys and values; the result lives in the input space.
	"""

	def __init__(self, in_dim, n_mem, mem_dim, key_dim) -> None:
		super().__init__()

		# Factor applied to the attention logits: key_dim ** -0.5.
		self.scale = key_dim**(-0.5)

		self.q = nn.Linear(in_dim, key_dim, bias=False)   # input  -> queries
		self.k = nn.Linear(mem_dim, key_dim, bias=False)  # memory -> keys
		self.v = nn.Linear(mem_dim, in_dim, bias=False)   # memory -> values
		# Mixing matrix for the gate: (x @ G) is multiplied with the values.
		self.G = nn.Parameter(torch.rand(in_dim, n_mem))
		# Bilinear form placed between queries and keys in the logits.
		self.U = nn.Parameter(torch.rand(key_dim, key_dim))

	def forward(self, x: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
		"""Read from ``M`` with the queries in ``x``.

		Args:
			x: queries of shape (..., n_queries, in_dim).
			M: memory of shape (..., n_mem, mem_dim).

		Returns:
			Tensor of shape (..., in_dim): the attended values summed over
			the query axis.
		"""
		queries = self.q(x)
		keys = self.k(M)
		values = self.v(M)

		# Logits are (..., n_queries, n_mem); each query's weights over the
		# memory rows sum to 1.
		logits = queries @ self.U @ keys.transpose(-1, -2) * self.scale
		attn = logits.softmax(dim=-1)

		# Gate in (0, 1) of shape (..., n_queries, in_dim). It is multiplied
		# with the values by broadcasting, which works when n_queries is 1
		# (or equal to n_mem).
		gate = (x @ self.G @ values).sigmoid()
		read = attn @ (gate * values)
		return read.sum(dim=-2)

In [23]:
class WriteBlock(nn.Module):
	"""Multi-head writer: several ``WriteHead``s followed by a linear merge."""

	def __init__(self, in_dim, n_mem, mem_dim, key_dim, n_heads) -> None:
		super().__init__()

		# Each head receives an equal (floor-divided) share of the key channels.
		head_key_dim = key_dim // n_heads
		heads = []
		for _ in range(n_heads):
			heads.append(WriteHead(in_dim, n_mem, mem_dim, head_key_dim))
		self.heads = nn.ModuleList(heads)

		# Maps the concatenated head outputs back to mem_dim channels.
		self.proj = nn.Linear(n_heads * mem_dim, mem_dim)

	def forward(self, x: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
		"""Run every head on ``(x, M)``, concatenate on the last axis, project."""
		per_head = tuple(head(x, M) for head in self.heads)
		merged = torch.cat(per_head, dim=-1)
		return self.proj(merged)


class ReadBlock(nn.Module):
	"""Multi-head reader: several ``ReadHead``s followed by a linear merge."""

	def __init__(self, in_dim, n_mem, mem_dim, key_dim, n_heads) -> None:
		super().__init__()

		# Each head receives an equal (floor-divided) share of the key channels.
		head_key_dim = key_dim // n_heads
		heads = []
		for _ in range(n_heads):
			heads.append(ReadHead(in_dim, n_mem, mem_dim, head_key_dim))
		self.heads = nn.ModuleList(heads)

		# Maps the concatenated head outputs back to in_dim channels.
		self.proj = nn.Linear(n_heads * in_dim, in_dim)

	def forward(self, x: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
		"""Run every head on ``(x, M)``, concatenate on the last axis, project."""
		per_head = tuple(head(x, M) for head in self.heads)
		merged = torch.cat(per_head, dim=-1)
		return self.proj(merged)

In [24]:
class Memory(nn.Module):
	def __init__(self, in_dim, n_mem, mem_dim, key_dim, n_heads) -> None:
		super(Memory, self).__init__()

		self.register_buffer('M', torch.rand(1, n_mem, mem_dim))
		self.writer = WriteBlock(in_dim, n_mem, mem_dim, key_dim, n_heads)
		self.reader = ReadBlock(in_dim, n_mem, mem_dim, key_dim, n_heads)

	def reset_memory(self) -> None:
		self.M = torch.rand(1, *self.M.shape[1:])
		nn.init.xavier_uniform_(self.M)

	def forward(self, x: torch.Tensor, op: str) -> torch.Tensor | None:
		if x.ndim == 2:
			x = x.unsqueeze(0)

		if op == 'w':
			self.M = self.M + self.writer(x, self.M)
		elif op == 'r':
			x = x.transpose(0, 1)
			return self.reader(x, self.M)
		return None

In [25]:
# Model hyper-parameters.
IN_DIM = len_abc  # width of the one-hot input: one channel per alphabet character
N_MEM = 32
MEM_DIM = 16
KEY_DIM = 16
N_HEADS = 4

# WriteBlock/ReadBlock give each head KEY_DIM // N_HEADS key channels, so a
# non-divisible pair would silently drop channels.
assert KEY_DIM % N_HEADS == 0, 'KEY_DIM must be divisible by N_HEADS'

# Seed the global RNG so that parameter initialisation and the torch.rand
# draws in the following cells are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

mem = Memory(IN_DIM, N_MEM, MEM_DIM, KEY_DIM, N_HEADS)
print('model size:', sum(p.numel() for p in mem.parameters()))

model size: 67283


In [26]:
# Sample sentence used as the input sequence by the demo cells that follow.
txt = 'The quick brown fox jumps over the lazy dog.'
# Alphabet indices of `txt`, one int per character.
encoded = encode(txt)

In [27]:
# WRITING: UNBATCHED INPUT

mem.reset_memory()
print(f'{mem.M.shape=}')

# A 2-D (seq_len, IN_DIM) input is treated as a single sequence.
seq = enc2seq(encoded)
# Call the module rather than .forward() so that registered hooks run.
mem(seq, op='w')
print(f'{mem.M.shape=}')

assert mem.M.shape == (1, N_MEM, MEM_DIM)

mem.M.shape=torch.Size([1, 32, 16])
mem.M.shape=torch.Size([1, 32, 16])


In [28]:
# WRITING: BATCHED INPUT

BATCH_SIZE = 64

mem.reset_memory()
print(f'{mem.M.shape=}')

seq = enc2seq(encoded)
# Keep `seq` as the single sequence; the batched view gets its own name.
batch = seq.expand(BATCH_SIZE, *seq.shape)
# Call the module rather than .forward() so that registered hooks run.
mem(batch, op='w')
print(f'{mem.M.shape=}')

# The (1, N_MEM, MEM_DIM) memory is broadcast to one copy per batch element.
assert mem.M.shape == (BATCH_SIZE, N_MEM, MEM_DIM)

mem.M.shape=torch.Size([1, 32, 16])
mem.M.shape=torch.Size([64, 32, 16])


In [29]:
# READING: UNBATCHED INPUT

mem.reset_memory()
print(f'{mem.M.shape=}')

x = torch.rand(1, IN_DIM)
# Call the module rather than .forward() so that registered hooks run.
m_r = mem(x, op='r')

print(f'{m_r.shape=}')

assert m_r.shape == (1, IN_DIM)

mem.M.shape=torch.Size([1, 32, 16])
m_r.shape=torch.Size([1, 95])


In [30]:
# READING: BATCHED INPUT

BATCH_SIZE = 64

mem.reset_memory()
print(f'{mem.M.shape=}')

# Each row of x is an independent query against the same memory.
x = torch.rand(BATCH_SIZE, IN_DIM)
# Call the module rather than .forward() so that registered hooks run.
m_r = mem(x, op='r')

print(f'{m_r.shape=}')

assert m_r.shape == (BATCH_SIZE, IN_DIM)

mem.M.shape=torch.Size([1, 32, 16])
m_r.shape=torch.Size([64, 95])


In [31]:
# WRITE, THEN READ, THEN BACKWARD(): UNBATCHED INPUT

# Clear gradients left by earlier cells (or an earlier run of this cell) so
# the check below only reflects this backward pass.
mem.zero_grad()
mem.reset_memory()
print(f'{mem.M.shape=}')

seq = enc2seq(encoded)
# Call the module rather than .forward() so that registered hooks run.
mem(seq, op='w')
print(f'{mem.M.shape=}')

x = torch.rand(1, IN_DIM)
m_r = mem(x, op='r')
print(f'{m_r.shape=}')

m_r.sum().backward()

# The read depends on the written memory, so the writer must receive gradients.
assert mem.writer.proj.weight.grad is not None

mem.M.shape=torch.Size([1, 32, 16])
mem.M.shape=torch.Size([1, 32, 16])
m_r.shape=torch.Size([1, 95])


In [32]:
# WRITE, THEN READ, THEN BACKWARD(): BATCHED INPUT

BATCH_SIZE = 64

# Clear gradients left by earlier cells (or an earlier run of this cell) so
# the check below only reflects this backward pass.
mem.zero_grad()
mem.reset_memory()
print(f'{mem.M.shape=}')

seq = enc2seq(encoded)
batch = seq.expand(BATCH_SIZE, *seq.shape)
# Call the module rather than .forward() so that registered hooks run.
mem(batch, op='w')
print(f'{mem.M.shape=}')

x = torch.rand(BATCH_SIZE, IN_DIM)
m_r = mem(x, op='r')
print(f'{m_r.shape=}')

m_r.sum().backward()

# The read depends on the written memory, so the writer must receive gradients.
assert mem.writer.proj.weight.grad is not None

mem.M.shape=torch.Size([1, 32, 16])
mem.M.shape=torch.Size([64, 32, 16])
m_r.shape=torch.Size([64, 95])
