In [37]:
import string

import torch
import torch.nn as nn
import torch.nn.functional as F

In [38]:
# Character vocabulary: a space, then ASCII punctuation, digits and letters.
abc = ''.join([' ', string.punctuation, string.digits, string.ascii_letters])
len_abc = len(abc)
print(f'{abc=}\n{len_abc=}')

# Lookup tables: index -> character and character -> index.
itoc = dict(enumerate(abc))
ctoi = {ch: idx for idx, ch in itoc.items()}

def encode(s: str) -> list[int]:
    """Translate a string into a list of vocabulary indices.

    Each character is looked up in the module-level ``ctoi`` table; a
    character that is not a key of ``ctoi`` raises ``KeyError``.
    """
    indices = []
    for ch in s:
        indices.append(ctoi[ch])
    return indices

def decode(l: list[int]) -> str:
    """Translate a list of vocabulary indices back into a string.

    Each index is looked up in the module-level ``itoc`` table; an index
    that is not a key of ``itoc`` raises ``KeyError``.
    """
    chars = (itoc[idx] for idx in l)
    return ''.join(chars)

def enc2tnsr(l: list[int]) -> torch.Tensor:
    """Convert a list of indices into a 64-bit integer tensor."""
    as_tensor = torch.tensor(l)
    return as_tensor.to(torch.long)

def enc2seq(l: list[int]) -> torch.Tensor:
    """One-hot encode a list of vocabulary indices.

    Returns a float tensor with one row per index and ``len_abc`` columns.
    """
    indices = enc2tnsr(l)
    one_hot = F.one_hot(indices, num_classes=len_abc)
    return one_hot.to(torch.float32)

def enct2seq(t: torch.Tensor) -> torch.Tensor:
    """One-hot encode an integer index tensor.

    Returns a float tensor with a trailing dimension of size ``len_abc``
    appended to the shape of ``t``.
    """
    one_hot = F.one_hot(t, num_classes=len_abc)
    return one_hot.to(torch.float32)

def str2seq(s: str) -> torch.Tensor:
    """One-hot encode a string, one row per character.

    The string is first turned into indices with ``encode`` and then
    expanded to a float tensor with ``len_abc`` columns.
    """
    indices = torch.tensor(encode(s)).to(torch.long)
    one_hot = F.one_hot(indices, num_classes=len_abc)
    return one_hot.to(torch.float32)

abc=' !"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
len_abc=95


In [39]:
class AttentionHead(nn.Module):
    """Single scaled dot-product attention head.

    Queries and keys are projected to ``key_dim`` features. Values are
    projected with an ``in_dim -> in_dim`` linear map, so the output keeps
    the feature size of the input.

    Args:
        in_dim: feature size of the inputs (their last dimension).
        key_dim: feature size of the query/key projections.
    """

    def __init__(self, in_dim, key_dim) -> None:
        super().__init__()

        # 1/sqrt(key_dim) scaling applied to the attention scores.
        self.scaler = key_dim**-0.5

        self.q = nn.Linear(in_dim, key_dim, bias=False)
        self.k = nn.Linear(in_dim, key_dim, bias=False)
        self.v = nn.Linear(in_dim, in_dim, bias=False)

    def forward(self, Xq: torch.Tensor, Xk: torch.Tensor, Xv: torch.Tensor, masked: bool = False) -> torch.Tensor:
        """Attend from ``Xq`` over ``Xk`` / ``Xv``.

        Args:
            Xq: query source, shape ``(..., n_q, in_dim)``.
            Xk: key source, shape ``(..., n_k, in_dim)``.
            Xv: value source, shape ``(..., n_k, in_dim)``.
            masked: when True, query position ``i`` may only attend to key
                positions ``j <= i`` (causal mask).

        Returns:
            Tensor of shape ``(..., n_q, in_dim)``.
        """
        Q, K, V = self.q(Xq), self.k(Xk), self.v(Xv)
        # Negative axes so the head works for any number of leading dims.
        A = (Q @ K.transpose(-2, -1)) * self.scaler
        if masked:
            # The mask must span the (query, key) sequence axes of the score
            # matrix. The previous code sized it with len(Xq), i.e. the batch
            # dimension, which disabled masking for a batch of one.
            n_q, n_k = A.shape[-2], A.shape[-1]
            future = torch.ones(n_q, n_k, device=A.device).triu(diagonal=1).bool()
            A = A.masked_fill(future, -torch.inf)
        A = F.softmax(A, dim=-1)
        return A @ V

class MultiHeadAttention(nn.Module):
    """Several ``AttentionHead`` modules run in parallel and merged by a linear layer.

    Each head is built with a query/key size of ``key_dim // n_heads``. The
    head outputs are concatenated along dimension 2 and projected from
    ``n_heads * in_dim`` back to ``in_dim`` features.
    """

    def __init__(self, n_heads, in_dim, key_dim) -> None:
        super().__init__()

        head_key_dim = key_dim // n_heads
        self.heads = nn.ModuleList(
            AttentionHead(in_dim, head_key_dim) for _ in range(n_heads)
        )
        self.proj = nn.Linear(n_heads * in_dim, in_dim)

    def forward(self, Xq: torch.Tensor, Xk: torch.Tensor, Xv: torch.Tensor, masked: bool = False) -> torch.Tensor:
        """Run every head on the same inputs and project the concatenated result."""
        per_head = []
        for head in self.heads:
            per_head.append(head(Xq, Xk, Xv, masked))
        merged = torch.cat(per_head, dim=2)
        return self.proj(merged)

In [40]:
class EncoderBlock(nn.Module):
    """Self-attention followed by a two-layer MLP, each with a residual connection."""

    def __init__(self, n_heads, in_dim, key_dim, h_dim) -> None:
        super().__init__()

        self.attention = MultiHeadAttention(n_heads, in_dim, key_dim)
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, in_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply residual self-attention, then the residual MLP."""
        attended = x + self.attention(x, x, x)
        transformed = self.mlp(attended)
        return attended + transformed

In [41]:
class DecoderBlock(nn.Module):
    """Masked self-attention, cross-attention and an MLP, each with a residual connection."""

    def __init__(self, n_heads, in_dim, key_dim, h_dim) -> None:
        super().__init__()

        self.self_attn = MultiHeadAttention(n_heads, in_dim, key_dim)
        self.cross_attn = MultiHeadAttention(n_heads, in_dim, key_dim)
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, in_dim),
        )

    def forward(self, Xq: torch.Tensor, Xkv: torch.Tensor) -> torch.Tensor:
        """Decode ``Xq`` while attending to the context ``Xkv``.

        ``Xq`` first attends to itself with ``masked=True``, then queries
        ``Xkv`` (used as both keys and values), then passes through the MLP.
        """
        after_self = Xq + self.self_attn(Xq, Xq, Xq, masked=True)
        after_cross = after_self + self.cross_attn(after_self, Xkv, Xkv)
        return after_cross + self.mlp(after_cross)

In [42]:
# Model hyper-parameters. Inputs are one-hot characters, so the feature size
# equals the vocabulary size.
IN_DIM = len_abc
KEY_DIM = 16          # total query/key size, split across the heads
N_HEADS = 4
H_DIM = len_abc*2     # hidden width of each block's MLP
N_BLOCKS = 4
OUT_DIM = len_abc

# Positional arguments shared by EncoderBlock and DecoderBlock.
args = [N_HEADS, IN_DIM, KEY_DIM, H_DIM]

# NOTE(review): the N_BLOCKS-deep lists are built here, but the visible cells
# only ever call the single `encoder` / `decoder` blocks below.
encoder_modules = [EncoderBlock(*args) for _ in range(N_BLOCKS)]
encoder = EncoderBlock(*args)

decoder_modules = [DecoderBlock(*args) for _ in range(N_BLOCKS)]
decoder = DecoderBlock(*args)

# Maps decoder features to a probability distribution over the vocabulary.
# Softmax runs over the last axis so the vocabulary is normalised for both
# (seq, vocab) and batched (batch, seq, vocab) inputs; dim=1 was only correct
# for the 2-D case.
classifier = nn.Sequential(
    nn.Linear(IN_DIM, OUT_DIM),
    nn.Softmax(dim=-1)
)

In [45]:
MAX_NEW_TOKENS = 100  # number of characters to sample

prompt = 'The quick brown fox jumps over the lazy dog'
enc = encode(prompt)
seq = enc2seq(enc).unsqueeze(0)  # add a batch dimension of 1

# Sampling only: no gradients are needed, so skip building autograd graphs.
with torch.no_grad():
    # Call the modules rather than .forward() so registered hooks still run.
    history = encoder(seq)

    out = [0]  # start token: index 0 is the space character
    for _ in range(MAX_NEW_TOKENS):
        seq = enc2seq(out).unsqueeze(0)
        dec_out = decoder(seq, history)
        # The next character is predicted from the LAST decoder position;
        # position 0 only ever corresponds to the start token.
        p = classifier(dec_out[0])[-1]
        i = torch.multinomial(p, 1).item()

        # NOTE(review): the sampled token's raw one-hot vector is appended to
        # the encoder output, mixing encoded and un-encoded rows in `history`.
        # Kept as-is; confirm this is intended.
        t = enc2seq([i]).unsqueeze(0)
        history = torch.cat((history, t), dim=1)
        out.append(i)

print(decode(out))

 A;xOJqSf3Z9/sB|#E@A67!cCCze}gm)]tAq gmlAx#m -wmwqe9]_fho1i{^p!}A1V9@`+A^no-?%HwGSm[M$uD^;Md]Ni; B|?{
