In [1]:
import math
import torch 
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# # tiny corpus
# text = ("hello world. \n"
#         "hello transformer. \n"
#         "attention is all you need. \n"
#         "hello attention. \n"
#         "transformers generate text. \n")

In [37]:
text = ("sawubona mfo. \n"
        "sawubona transformer. \n"
        "sifuna i attention. \n"
        "sawubona attention. \n"
        "transformers zikhiqiza umbhalo. \n")

In [39]:
# build vocab (characters)
chars = sorted(list(set(text)))
enum_chars = {ch:i for i, ch in enumerate(chars)} #to mimic transformers tokenizer, we will need to tokenize subwords later
item_chars = {i:ch for ch,i in enum_chars.items()}
#vocab size from the corpus
vocab_size = len(chars)

In [40]:
def encode(s):
    return torch.tensor([enum_chars[c] for c in s], dtype=torch.long)

In [41]:
def decode(ids):
    return "".join([item_chars[i] for i in ids])

In [42]:
input_data = encode(text)

In [43]:
class TinyGPT(nn.Module):
    def __init__(self, vocab_size, d_model = 64, nhead = 8, num_layers = 2, dim_feedforward = 128, block_size = 128, dropout = 0.1):
        super().__init__()
        self.block_size = block_size
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.positional_embedding = nn.Embedding(block_size, d_model)

        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward,dropout=dropout,activation='gelu',batch_first=True)
        self.tr = nn.TransformerEncoder(encoder_layer, num_layers = num_layers)
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, idx):
        B, T = idx.shape
        if T > self.block_size:
            idx = idx[:, -self.block_size:]
            T = self.block_size

        position_ids = torch.arange(T, device = idx.device).unsqueeze(0).expand(B,T)
        x = self.token_embedding(idx) + self.positional_embedding(position_ids)

    # prevent looking ahead
        attn_mask = torch.triu(torch.ones(T,T, device=idx.device),diagonal=1).bool()
        x = self.tr(x, mask = attn_mask)
        logits = self.lm_head(x)
        return logits

In [44]:
@torch.no_grad()
def generate(model, start_text, max_new_chars=200, temperature = 1.0, top_k = None, device = 'cpu'):
    model.eval()
    idx = encode(start_text).unsqueeze(0).to(device)

    for _ in range(max_new_chars):
        logits = model(idx)[:,-1,:] / max(temperature, 1e-6)
        if top_k is not None:
            v, ix = torch.topk(logits, k = top_k, dim = 1)
            mask = torch.full_like(logits, float("-inf"))
            mask.scatter_(1, ix, v)
            logits = mask

        probs = F.softmax(logits, dim = 1)
        next_id = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, next_id], dim = 1)

    return decode(idx.squeeze(0).tolist())

In [45]:
#Training loop 
def get_batch(data, batch_size = 32, block_size = 64, device = 'cpu'):
    ix = torch.randint(0, len(data) - block_size -1, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix]).to(device)
    y = torch.stack([data[i+1:i+block_size+1] for i in ix]).to(device)
    return x, y

In [48]:
def train():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    torch.manual_seed(0)

    model = TinyGPT(vocab_size, nhead = 4, num_layers = 2, dim_feedforward = 128, block_size = 128).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr = 3e-4)

    steps = 1200
    batch_size = 64
    block_size = 64

    model.train()
    for step in range(1, steps + 1):
        x, y = get_batch(input_data, batch_size, block_size, device = device)
        logits = model(x)
        loss = F.cross_entropy(logits.reshape(-1, vocab_size), y.reshape(-1))

        opt.zero_grad()
        loss.backward()
        opt.step()

        if step % 200 == 0:
            ppl = math.exp(loss.item())
            print(f"step {step:4d} | loss {loss.item():.3f} | ppl {ppl:.2f}")

    #generate sample text
    print("\n--- Generated text ---")
    for prompt in ["sawubona", "zikhiqiza", "transform"]:
        out = generate(model, prompt, max_new_chars = 120, temperature = 0.9, top_k = 10, device = device)
        print(f"\nPrompt: {prompt!r}\n{out}")

In [49]:
train()

step  200 | loss 0.953 | ppl 2.59
step  400 | loss 0.298 | ppl 1.35
step  600 | loss 0.119 | ppl 1.13
step  800 | loss 0.078 | ppl 1.08
step 1000 | loss 0.052 | ppl 1.05
step 1200 | loss 0.043 | ppl 1.04

--- Generated text ---

Prompt: 'sawubona'
sawubona transformer. 
sifuna i attention. 
sawubona attention. 
a 
tion. ntention. 
trantrmention. 
tionsawuwubonsa zikhiontrsa

Prompt: 'zikhiqiza'
zikhiqiza attention. 
sawubona attention. 
transformers zikhizalosa ubon. umena a 
saunsawubontionsfumikha 
santr. 
tiqizifurs a 

Prompt: 'transform'
transformer. 
sifuna i attention. 
sawubona attention. 
transformens 
sformensa an. ziontionansansfomersfon. an. antionsanansfons
