In [1]:
import string

import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class AttentionHead(nn.Module):
    """A single causal (masked) self-attention head.

    Queries and keys are projected to ``emb_dim`` features; values keep the
    input width ``in_dim``, so the head maps ``(..., T, in_dim)`` to
    ``(..., T, in_dim)``. Position ``i`` only attends to positions ``<= i``.

    Args:
        emb_dim: query/key width; also sets the ``emb_dim ** -0.5`` score scaling.
        ctx_size: maximum sequence length ``T`` supported by the causal mask.
        in_dim: feature width of the input (and of the output).
    """

    def __init__(self, emb_dim, ctx_size, in_dim) -> None:
        super().__init__()

        self.emb_dim = emb_dim

        self.q = nn.Linear(in_dim, emb_dim, bias=False)
        self.k = nn.Linear(in_dim, emb_dim, bias=False)
        # Values keep the input width, so every head outputs in_dim features.
        self.v = nn.Linear(in_dim, in_dim, bias=False)

        # Boolean (ctx_size, ctx_size) mask, True strictly above the diagonal:
        # those are the "future" positions a query must not attend to.
        self.register_buffer('mask', torch.tril(torch.ones(ctx_size, ctx_size)) == 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal attention to ``x`` of shape ``(..., T, in_dim)``.

        ``T`` may be anything up to ``ctx_size``; a longer sequence fails
        because the mask cannot cover it.
        """
        seq_len = x.shape[-2]
        Q, K, V = self.q(x), self.k(x), self.v(x)
        A = (Q @ K.transpose(-2, -1)) * self.emb_dim ** -0.5
        # Slice the mask to the actual length: a full (ctx_size, ctx_size) mask
        # would not broadcast against a (T, T) score matrix when T < ctx_size.
        A = A.masked_fill(self.mask[:seq_len, :seq_len], -torch.inf)
        A = F.softmax(A, dim=-1)
        return A @ V

In [3]:
class MultiHeadAttention(nn.Module):
    """Run several causal ``AttentionHead``s in parallel and mix their outputs.

    Each head gets a query/key width of ``emb_dim // n_heads`` and returns
    ``in_dim`` features (its value projection keeps the input width). The
    concatenation is therefore ``n_heads * in_dim`` wide before ``proj`` maps
    it back to ``in_dim``.

    Args:
        n_heads: number of heads (must be >= 1).
        emb_dim: total query/key width shared across heads; any remainder of
            ``emb_dim / n_heads`` is dropped by the integer division.
        ctx_size: maximum sequence length (size of each head's causal mask).
        in_dim: feature width of the input and of the output.

    Raises:
        ValueError: if ``n_heads < 1`` or ``emb_dim // n_heads`` is 0. Zero-width
            heads would otherwise only fail later, inside ``forward``, when
            computing the score scaling.
    """

    def __init__(self, n_heads, emb_dim, ctx_size, in_dim) -> None:
        super().__init__()

        if n_heads < 1 or emb_dim // n_heads < 1:
            raise ValueError(
                f"need n_heads >= 1 and emb_dim // n_heads >= 1, "
                f"got n_heads={n_heads}, emb_dim={emb_dim}"
            )

        head_dim = emb_dim // n_heads
        self.heads = nn.ModuleList(
            [AttentionHead(head_dim, ctx_size, in_dim) for _ in range(n_heads)]
        )
        self.proj = nn.Linear(n_heads * in_dim, in_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Concatenate every head's output on the feature axis and project to ``in_dim``."""
        # Concatenate on the last (feature) axis so any input rank the heads accept works.
        t = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.proj(t)

In [4]:
class Block(nn.Module):
    """One layer of the model: multi-head causal attention, then a ReLU MLP.

    Each sub-layer is added back onto its input (a residual connection). No
    normalisation layers are used. Input and output both have ``in_dim``
    features.
    """

    def __init__(self, n_heads, emb_dim, h_dim, ctx_size, in_dim) -> None:
        super().__init__()

        self.attention = MultiHeadAttention(n_heads, emb_dim, ctx_size, in_dim)
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, in_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual attention first; the MLP then refines the attended features.
        attended = x + self.attention(x)
        refined = attended + self.mlp(attended)
        return refined

In [5]:
# Character vocabulary: a space, ASCII punctuation, digits, then lower- and
# upper-case letters. A token id is a character's index in this string.
abc = ''.join([' ', string.punctuation, string.digits, string.ascii_letters])
len_abc = len(abc)

# Lookup tables in both directions: id -> character and character -> id.
itoc = dict(enumerate(abc))
ctoi = {ch: idx for idx, ch in itoc.items()}

print(len_abc, abc)

95  !"#$%&'()*+,-./:;<=>?@[\]^_`{|}~0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ


In [6]:
def decode(l):
    """Turn an iterable of integer token ids back into a string using ``itoc``."""
    chars = (itoc[idx] for idx in l)
    return ''.join(chars)

def list_to_seq(l):
    """One-hot encode token ids as a float tensor.

    For a flat list of ids, the result has shape ``(len(l), len_abc)``.
    """
    ids = torch.tensor(l).long()
    encoded = F.one_hot(ids, num_classes=len_abc)
    return encoded.float()

In [7]:
# Seed torch's RNG before the model is built, so weight initialisation and
# sampling are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

ctx_size = 4        # tokens the model sees at once (size of the causal mask)
in_dim = len_abc    # model width = vocabulary size, since inputs are one-hot vectors

n_blocks = 4        # number of stacked Blocks
n_heads = 4         # attention heads per Block
emb_dim = 16        # total query/key width, split across heads (emb_dim // n_heads each)
h_dim = 32          # hidden width of each Block's MLP

In [8]:
# Constructor arguments shared by every Block, in Block.__init__'s positional order.
args = [n_heads, emb_dim, h_dim, ctx_size, in_dim]

# n_blocks separately constructed (independently initialised) Blocks, applied in sequence.
modules = [Block(*args) for _block_idx in range(n_blocks)]
model = nn.Sequential(*modules)

In [11]:
# Autoregressive sampling. The context starts as ctx_size copies of token 0
# (a space); each step predicts the next character from the last ctx_size
# tokens and appends it.
n_new_tokens = 128
output = [0] * ctx_size

with torch.no_grad():  # inference only: no autograd graph needed
    for step in range(n_new_tokens):
        x = list_to_seq(output[-ctx_size:]).unsqueeze(0)  # (1, ctx_size, len_abc)
        # Under the causal mask, only the LAST position has seen the whole
        # context, so its output is the next-token prediction. (Index 0 would
        # condition on the oldest token alone.)
        logits = model(x)[0, -1]
        # softmax gives the same distribution as exp() normalised by multinomial,
        # but cannot overflow to inf.
        p = F.softmax(logits, dim=-1)
        i = torch.multinomial(p, 1)
        output.append(i.item())

print(decode(output[ctx_size:]))

v~)2'r`HBp 4_(P47e|?<EkB.hxpDsFkv(Ygu.unrAck7D\VNFO-asXmpfreN=3O=.uSN;e9usOeFIpA Iq`2}Z:oohb"pNS>Ee?"wZYz0&i\3/v6pV,DIBOY#kyARKb
