In [2]:
# Read the whole corpus into memory; bytes that are not valid UTF-8 are dropped.
with open('grimms.txt', encoding='utf-8', errors='ignore') as corpus_file:
    text = corpus_file.read()

# Number of characters that were read.
len(text)

537278

In [3]:
import string

# Drop every character that is not part of string.printable.
allowed = set(string.printable)
text = ''.join(filter(allowed.__contains__, text))

# Character-level vocabulary: every distinct character, in sorted order,
# paired with its index.
chars = sorted(set(text))
itos = dict(enumerate(chars))
stoi = {ch: i for i, ch in itos.items()}


def encode(s):
    """Map each character of ``s`` to its integer id in ``stoi``."""
    return [stoi[ch] for ch in s]


def decode(l):
    """Join the characters that ``itos`` assigns to the ids in ``l``."""
    return ''.join(itos[token_id] for token_id in l)


vocab_size = len(chars)
print(vocab_size)

82


In [4]:
import torch
# Encode the whole corpus as a single 1-D tensor of int64 token ids.
data = torch.tensor(encode(text), dtype=torch.long)

In [5]:
# Hold out the tail of the corpus: the first 90% is used for training,
# the remaining 10% for testing.
n = int(0.9 * data.size(0))

train_data, test_data = data[:n], data[n:]

print(train_data.shape, test_data.shape)

torch.Size([478924]) torch.Size([53214])


In [6]:
# Peek at the first 100 token ids of the training split.
train_data[:100]

tensor([46, 34, 31,  1, 28, 44, 41, 46, 34, 31, 44, 45,  1, 33, 44, 35, 39, 39,
         1, 32, 27, 35, 44, 51,  1, 46, 27, 38, 31, 45,  0,  0,  0,  0,  0, 46,
        34, 31,  1, 33, 41, 38, 30, 31, 40,  1, 28, 35, 44, 30,  0,  0,  0, 27,
         1, 58, 60, 73, 75, 56, 64, 69,  1, 66, 64, 69, 62,  1, 63, 56, 59,  1,
        56,  1, 57, 60, 56, 76, 75, 64, 61, 76, 67,  1, 62, 56, 73, 59, 60, 69,
        10,  1, 56, 69, 59,  1, 64, 69,  1, 75])

In [7]:
context_length = 8
# A chunk of context_length + 1 tokens supplies context_length inputs plus the
# one extra token needed as the last target.
train_data[:context_length+1]

tensor([46, 34, 31,  1, 28, 44, 41, 46, 34])

In [8]:
# Every prefix of a chunk is a training example: the tokens seen so far are
# the input and the token that follows them is the target.
x = train_data[:context_length]
y = train_data[1:context_length+1]

for t, target in enumerate(y):
    context = x[:t + 1]
    print(f'when input is {context} the target: {target}')

when input is tensor([46]) the target: 34
when input is tensor([46, 34]) the target: 31
when input is tensor([46, 34, 31]) the target: 1
when input is tensor([46, 34, 31,  1]) the target: 28
when input is tensor([46, 34, 31,  1, 28]) the target: 44
when input is tensor([46, 34, 31,  1, 28, 44]) the target: 41
when input is tensor([46, 34, 31,  1, 28, 44, 41]) the target: 46
when input is tensor([46, 34, 31,  1, 28, 44, 41, 46]) the target: 34


In [9]:
torch.manual_seed(42)
context_length = 8  # maximum number of tokens the model looks at
batch_size = 32


def get_batch(split):
    """Sample a random batch of (input, target) windows from one data split.

    ``split == 'train'`` reads from ``train_data``; any other value reads from
    ``test_data``.  Returns a pair of tensors, each of shape
    ``(batch_size, context_length)``; the second is the first shifted one
    token to the right.
    """
    source = train_data if split == 'train' else test_data
    # Random window starts; the upper bound leaves room for the shifted target.
    starts = torch.randint(len(source) - context_length, (batch_size,))
    # Row b holds the positions starts[b], starts[b] + 1, ...
    window = starts.unsqueeze(1) + torch.arange(context_length)
    return source[window], source[window + 1]

In [10]:
# Sample one training batch; xb has shape (batch_size, context_length).
xb, yb = get_batch('train')
xb.shape

torch.Size([32, 8])

In [11]:
# Number of distinct characters in the vocabulary (len(chars)).
vocab_size

82

In [12]:
import torch.nn as nn
import torch.nn.functional as F

n_embd = 64  # size of each embedding vector
# Lookup tables: one n_embd-sized vector per vocabulary entry and one per position.
token_embedding = nn.Embedding(num_embeddings=len(chars), embedding_dim=n_embd)
position_embedding = nn.Embedding(num_embeddings=context_length, embedding_dim=n_embd)

In [13]:
tok = token_embedding(xb) # (B, T, C) = (32, 8, 64): one embedding vector per token
tok.shape

torch.Size([32, 8, 64])

In [14]:
# One embedding vector for each position 0 .. context_length - 1.
pos = position_embedding(torch.arange(context_length))
pos.shape
# x = tok + pos  # (B, T, C): pos is broadcast across the batch dimension

torch.Size([8, 64])

# self-attention

In [33]:
class SelfAttention(nn.Module):
    """Single-head causal self-attention.

    Each position attends only to itself and to earlier positions; a
    lower-triangular mask hides later positions before the softmax.

    Args:
        n_embd: Size of the input embedding. Keys, queries and values are
            projected to this same size.
        block_size: Longest sequence the causal mask supports. When omitted,
            the notebook-level ``context_length`` is used.
    """

    def __init__(self, n_embd, block_size=None):
        super().__init__()
        if block_size is None:
            # Default: fall back to the global defined earlier in the notebook.
            block_size = context_length

        self.key   = nn.Linear(n_embd, n_embd, bias=False)
        self.query = nn.Linear(n_embd, n_embd, bias=False)
        self.value = nn.Linear(n_embd, n_embd, bias=False)

        # Registered as a buffer so it follows the module across devices
        # without being treated as a trainable parameter.
        self.register_buffer(
            'mask', torch.tril(torch.ones(block_size, block_size))
        )

    def forward(self, x):
        """Apply causal attention to ``x`` of shape (B, T, C); returns (B, T, C).

        Raises:
            ValueError: If T exceeds the size of the causal mask.
        """
        B, T, C = x.shape
        max_len = self.mask.size(0)
        if T > max_len:
            # Without this check the failure is an opaque broadcasting error
            # inside masked_fill.
            raise ValueError(
                f'sequence length {T} exceeds the attention mask size {max_len}'
            )

        k = self.key(x)    # (B, T, C)
        q = self.query(x)  # (B, T, C)

        # Scaled dot-product scores; dividing by sqrt(key size) keeps the
        # softmax from saturating.
        scores = q @ k.transpose(-2, -1) / (k.size(-1) ** 0.5)  # (B, T, T)
        # -inf scores become zero weight after the softmax.
        scores = scores.masked_fill(self.mask[:T, :T] == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)

        v = self.value(x)  # (B, T, C)
        return weights @ v  # (B, T, T) @ (B, T, C) -> (B, T, C)

In [34]:
# transformer block
class TransformerBlock(nn.Module):
    """Self-attention followed by a feed-forward MLP.

    Each sub-layer receives a layer-normalised input and its output is added
    back to the un-normalised input (residual connection).
    """

    def __init__(self, n_embd):
        super().__init__()
        hidden = 4 * n_embd  # width of the feed-forward hidden layer

        self.attn = SelfAttention(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ff = nn.Sequential(
            nn.Linear(n_embd, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embd),
        )
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Run both residual sub-layers on ``x`` and return the result."""
        attended = self.attn(self.ln1(x))
        x = x + attended
        transformed = self.ff(self.ln2(x))
        return x + transformed


In [35]:
# our minigpt
class MiniGPT(nn.Module):
    """Character-level language model with a single transformer block.

    Token and position embeddings are summed, passed through one
    ``TransformerBlock`` and a final LayerNorm, then projected to one logit
    per vocabulary entry.
    """

    def __init__(self):
        super().__init__()
        self.token_emb = nn.Embedding(len(chars), n_embd)
        self.pos_emb = nn.Embedding(context_length, n_embd)

        self.block = TransformerBlock(n_embd)
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, len(chars))

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            idx: Integer token ids of shape (B, T).
            targets: Optional token ids with B * T elements.

        Returns:
            ``(logits, loss)``. Without targets, logits has shape
            (B, T, vocab) and loss is ``None``. With targets, logits is
            flattened to (B * T, vocab) and loss is the cross-entropy.

        Raises:
            ValueError: If T exceeds the number of position embeddings.
        """
        B, T = idx.shape
        max_positions = self.pos_emb.num_embeddings
        if T > max_positions:
            # Fail early with a readable message instead of an index error
            # from the embedding lookup.
            raise ValueError(
                f'sequence length {T} exceeds the maximum context {max_positions}'
            )

        tok = self.token_emb(idx)
        # Build positions on the same device as the input so the model also
        # works when it has been moved off the CPU.
        pos = self.pos_emb(torch.arange(T, device=idx.device))
        x = tok + pos  # pos is broadcast across the batch dimension

        x = self.block(x)
        x = self.ln_f(x)
        logits = self.head(x)

        if targets is None:
            loss = None
        else:
            # cross_entropy expects (N, classes) logits and (N,) targets.
            logits = logits.view(B*T, -1)
            # reshape (rather than view) also accepts non-contiguous targets.
            targets = targets.reshape(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

In [36]:
# Train with AdamW on random training batches and print the batch loss
# at a fixed interval.
model = MiniGPT()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

num_steps = 3000
log_every = 500

for step in range(num_steps):
    xb, yb = get_batch('train')

    # Clear gradients left over from the previous step before the new pass.
    optimizer.zero_grad()
    logits, loss = model(xb, yb)
    loss.backward()
    optimizer.step()

    if step % log_every == 0:
        print(step, loss.item())

0 4.515272617340088
500 2.237753391265869
1000 2.066254138946533
1500 1.810591220855713
2000 2.1538896560668945
2500 1.9532157182693481


In [37]:
@torch.no_grad()
def generate(model, start_token, length=300):
    """Sample ``length`` new tokens from ``model`` and return the decoded text.

    Runs under ``torch.no_grad()`` so no autograd graph is built while
    sampling.

    Args:
        model: Callable returning ``(logits, loss)`` for a (1, T) tensor of
            token ids, with logits of shape (1, T, vocab).
        start_token: Integer id of the first token.
        length: Number of tokens to sample after the start token.

    Returns:
        The decoded string of ``length + 1`` tokens (start token included).
    """
    # Create the running sequence on the model's device; fall back to CPU for
    # callables that expose no parameters.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device('cpu')

    idx = torch.tensor([[start_token]], dtype=torch.long, device=device)
    for _ in range(length):
        # Only the most recent context_length tokens are fed to the model.
        idx_cond = idx[:, -context_length:]
        logits, _ = model(idx_cond)
        # Distribution over the next token, taken from the last position.
        probs = F.softmax(logits[:, -1, :], dim=-1)
        next_idx = torch.multinomial(probs, 1)
        idx = torch.cat([idx, next_idx], dim=1)
    return decode(idx[0].tolist())

# Seed generation with a single character and print 300 sampled characters.
start_char = 'T'        # or ' ', '\n'
start_token = stoi[start_char]

print(generate(model, start_token, length=300))

The aboden wayhe some in would end a her werythe min andranby camee the frivihtE Fre ck, He fore to med; I gan
ruve, ablad cret sheid her in all onen
ho cred he bawerst a do ther what offore cuseld ther ith
her dades of offled before
roughted drablone that said he toove wich. The wors mazzere dnowen 
