In [3]:
import torch
from torch.nn import Linear, GELU, CrossEntropyLoss, Sequential, Conv2d, MaxPool2d, Module, Softmax, BatchNorm2d, Dropout, functional as F

Let's start with masked (causal) self-attention, where each position may attend only to itself and to earlier positions.

![Self Attention](images/self_attention.png)

In [1]:
from gpt2attention import SelfAttention

# Reference attention module from the project's gpt2attention file.
# NOTE(review): presumably SelfAttention(embed_dim, n_heads) -> 768-wide, 12 heads; confirm argument order.
sa = SelfAttention(768, 12)

In [5]:
x = torch.randn(4, 50, 768)
# Invoke the module itself (nn.Module.__call__), not .forward(), so any registered hooks run.
sa(x).shape

torch.Size([4, 50, 768])

In [6]:
SEED = 0
torch.manual_seed(SEED)  # reproducible layer initialisation and random inputs below

MAX_CONTEXT = 128
dim = 768
n_heads = 12
embed_dim = dim  # model width; same quantity as `dim`, kept as a separate name for readability
assert dim % n_heads == 0, "dim must be divisible by n_heads"
context = 50 # sequence length: a context of 50 tokens
batch_size = 4

c_attn = Linear(dim, dim*3, bias=True) # W_q, W_k, W_v, that's why dim*3
c_proj = Linear(dim, dim, bias=True)

x = torch.randn(batch_size, context, embed_dim) # (batch_size, context, embed_dim)

In [34]:
def split_heads(x, num_heads=None, head_dim=None):
    """Reshape (batch, seq, num_heads * head_dim) into (batch, seq, num_heads, head_dim).

    Args:
        x: tensor whose last dimension packs all heads side by side.
        num_heads: number of heads; defaults to the notebook-level ``n_heads``.
        head_dim: per-head width; defaults to ``x.shape[-1] // num_heads`` so the
            split follows the actual input width rather than the global ``dim``.

    Returns:
        A view of ``x`` with the head axis split out.

    Raises:
        ValueError: if ``head_dim`` is not given and the last dimension of ``x``
            is not divisible by ``num_heads``.
    """
    if num_heads is None:
        num_heads = n_heads
    if head_dim is None:
        if x.shape[-1] % num_heads:
            raise ValueError(
                f"last dimension {x.shape[-1]} is not divisible by num_heads={num_heads}"
            )
        head_dim = x.shape[-1] // num_heads
    return x.view(x.shape[0], x.shape[1], num_heads, head_dim)

def merge_heads(x: torch.Tensor, num_heads, head_dim) -> torch.Tensor:
    """Flatten the head axes: (batch, seq, num_heads, head_dim) -> (batch, seq, num_heads * head_dim).

    The input is made contiguous first because ``.view`` requires contiguous
    memory. In this notebook the input is the output of ``attention``, which
    ends in a transpose.
    """
    batch, seq_len = x.shape[0], x.shape[1]
    merged_width = num_heads * head_dim
    return x.contiguous().view((batch, seq_len, merged_width))

In [35]:
def attention(q, k, v, mask=None, verbose=True):
    """Causal scaled dot-product attention over heads that are already split.

    Args:
        q: queries, shape (batch, query_len, n_heads, head_dim), the layout
            produced by ``split_heads``.
        k: keys, shape (batch, key_len, n_heads, head_dim).
        v: values, shape (batch, key_len, n_heads, head_dim).
        mask: optional additive mask broadcastable to
            (batch, n_heads, query_len, key_len); added to the scaled scores
            before the causal mask is applied.
        verbose: print intermediate shapes (the tutorial's original behavior).

    Returns:
        Tensor of shape (batch, query_len, n_heads, head_dim).
    """
    # Put heads before the sequence axis: (B, H, T, D).
    qh, kh, vh = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
    w = torch.matmul(qh, kh.transpose(-2, -1)) / (k.shape[-1] ** 0.5)
    if verbose:
        print(f'w.shape: {w.shape}')
        print(f'q.shape: {q.shape}')
        print(f'k.shape: {k.shape}')
        print(f'v.shape: {v.shape}')
    if mask is not None:
        w = w + mask

    query_len, key_len = q.shape[1], k.shape[1]
    # Query i sits at absolute position (key_len - query_len + i) and may attend to
    # every key up to and including that position. The diagonal offset matters when
    # fewer queries than keys are passed (e.g. decoding one token against a prefix).
    # Build the mask on w's device so torch.where works for CUDA tensors too.
    causal_mask = torch.tril(
        torch.ones((query_len, key_len), dtype=torch.bool, device=w.device),
        diagonal=key_len - query_len,
    )
    mask_value = torch.finfo(w.dtype).min  # stands in for -inf without producing NaNs
    w = torch.where(causal_mask, w, mask_value)
    if verbose:
        print(f'w.shape: {w.shape}')

    w = F.softmax(w, dim=-1)
    if verbose:
        print(f'w.shape after softmax: {w.shape}')
        print(f'v.shape: {v.shape}')
    # (B, H, Tq, Tk) @ (B, H, Tk, D) -> (B, H, Tq, D) -> back to (B, Tq, H, D).
    attn_output = torch.matmul(w, vh).transpose(1, 2)
    if verbose:
        print(f'attn_output.shape: {attn_output.shape}')
    return attn_output

In [36]:
# Forward pass of one attention layer, step by step:
# project x to Q, K and V with a single matmul, then split each into heads.
xqkv = c_attn(x)
queries, keys, values = (split_heads(part) for part in xqkv.split(dim, dim=2))

# Attend, flatten the heads back into the model width, then apply the output projection.
attn_output = c_proj(merge_heads(attention(queries, keys, values), n_heads, dim // n_heads))
attn_output.shape

# 4,50,768


w.shape: torch.Size([4, 12, 50, 50])
q.shape: torch.Size([4, 50, 12, 64])
k.shape: torch.Size([4, 50, 12, 64])
v.shape: torch.Size([4, 50, 12, 64])
w.shape: torch.Size([4, 12, 50, 50])
w.shape after softmax: torch.Size([4, 12, 50, 50])
v.shape: torch.Size([4, 50, 12, 64])
attn_output.shape: torch.Size([4, 50, 12, 64])


torch.Size([4, 50, 768])

In [15]:
# `tokenizer` is otherwise only created in a later cell; build it here so this cell runs on a fresh kernel.
import tiktoken
tokenizer = tiktoken.get_encoding("gpt2")
tokenizer.encode('Hello, my dog is cute')

[15496, 11, 616, 3290, 318, 13779]

In [1]:
from transformer import Transformer
from dataset import TextDataset
from torch.utils.data import DataLoader
import torch
import tiktoken

model_size = "gpt2"
tokenizer = tiktoken.get_encoding(model_size)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
torch.manual_seed(0)  # reproducible weight initialisation
model = Transformer(
            dim=768,
            n_heads=12,
            # Take the vocabulary size from the tokenizer (50257 for gpt2) instead of hard-coding it.
            vocab_size=tokenizer.n_vocab,
            n_layers=12,
            max_seq_len=128,
            device=device
        )

In [2]:
import os
from torch.utils.data.dataloader import default_collate

input_file_path = os.path.join("./data/shakespeare/", 'input.txt')
# Explicit encoding: the platform default is not guaranteed to be UTF-8.
with open(input_file_path, 'r', encoding='utf-8') as f:
    data = f.read()
n = len(data)


def collate_to_device(samples):
    """Batch samples with default_collate and move every resulting tensor to `device`.

    A named function (unlike a lambda) can be pickled, which DataLoader needs if
    num_workers is raised above 0 on platforms that spawn worker processes.
    """
    return tuple(t.to(device) for t in default_collate(samples))


dataset = TextDataset(data, tokenizer, max_length=128, input_type="text")
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0, collate_fn=collate_to_device)

337920 this has a different length: 114, padding
128
128


### Training

In [4]:
import torch

LEARNING_RATE = 1e-4
NUM_EPOCHS = 10
LOG_EVERY = 200  # batches between loss printouts

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
num_epochs = NUM_EPOCHS

# Enable training-mode behavior (e.g. dropout) even if an earlier cell left the model in eval mode.
model.train()
for epoch in range(num_epochs):
    for i, batch in enumerate(dataloader):
        inputs, labels = batch

        optimizer.zero_grad()
        logits, loss = model(inputs, labels=labels)
        loss.backward()
        optimizer.step()
        if i % LOG_EVERY == 0:
            print(f"epoch {epoch} batch {i}: loss: {loss.item()}")

epoch 0 batch 0: loss: 11.241231918334961
epoch 0 batch 200: loss: 5.63091516494751
epoch 0 batch 400: loss: 5.092859268188477
epoch 0 batch 600: loss: 4.970579147338867
epoch 1 batch 0: loss: 4.348081588745117
epoch 1 batch 200: loss: 4.332835674285889
epoch 1 batch 400: loss: 3.9893200397491455
epoch 1 batch 600: loss: 4.3360981941223145
epoch 2 batch 0: loss: 3.6253151893615723
epoch 2 batch 200: loss: 3.7650766372680664
epoch 2 batch 400: loss: 4.161316871643066
epoch 2 batch 600: loss: 3.8775484561920166
epoch 3 batch 0: loss: 3.393232822418213
epoch 3 batch 200: loss: 3.6106767654418945
epoch 3 batch 400: loss: 3.3115596771240234
epoch 3 batch 600: loss: 3.516842842102051
epoch 4 batch 0: loss: 1.9148331880569458
epoch 4 batch 200: loss: 2.240736484527588
epoch 4 batch 400: loss: 2.2979681491851807
epoch 4 batch 600: loss: 2.4867141246795654
epoch 5 batch 0: loss: 1.7106881141662598
epoch 5 batch 200: loss: 1.4227973222732544
epoch 5 batch 400: loss: 1.6919183731079102
epoch 5 ba

In [19]:
import torch.nn.functional as F

def _filter_logits(logits, top_k=0, top_p=1.0):
    """Mask logits outside the top-k and/or nucleus (top-p) set with -inf.

    Args:
        logits: scores of shape (batch, vocab).
        top_k: keep only the ``top_k`` highest-scoring tokens; 0 disables.
        top_p: keep the smallest set of tokens whose cumulative probability
            exceeds ``top_p``; values outside (0, 1) disable it.

    Returns:
        A new tensor in which the filtered-out positions are -inf.
    """
    if top_k > 0:
        k = min(top_k, logits.size(-1))
        kth_best = torch.topk(logits, k, dim=-1).values[..., -1:]
        logits = logits.masked_fill(logits < kth_best, float("-inf"))
    if 0.0 < top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, dim=-1, descending=True)
        cum_probs = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        sorted_drop = cum_probs > top_p
        # Shift right so the token that crosses the threshold is kept, and
        # always keep the single most likely token.
        sorted_drop[..., 1:] = sorted_drop[..., :-1].clone()
        sorted_drop[..., 0] = False
        drop = sorted_drop.scatter(-1, sorted_idx, sorted_drop)
        logits = logits.masked_fill(drop, float("-inf"))
    return logits


def generate(
            prompt,
            max_len=50,
            do_sample=True,
            temperature=0.1,
            top_k=0,
            top_p=0.9,
            repetition_penalty=1.0,
            num_return_sequences=1,
            batch_size=1,
            device="cuda",
            max_context=None,
            lm=None,
            enc=None):
    """Autoregressively continue ``prompt`` with the language model.

    Args:
        prompt: text to continue.
        max_len: number of new tokens to generate.
        do_sample: sample from the filtered distribution if True; otherwise
            take the arg-max token (greedy decoding).
        temperature: logits are divided by this before filtering/sampling;
            values <= 0 fall back to greedy decoding.
        top_k: keep only the k most likely tokens when sampling (0 = off).
        top_p: nucleus-sampling threshold (values outside (0, 1) = off).
        repetition_penalty: CTRL-style penalty; > 1 discourages tokens that
            already appear in the sequence (1.0 = off).
        num_return_sequences: number of independent continuations.
        batch_size: unused; kept for backward compatibility.
        device: preferred device; falls back to CPU when CUDA is unavailable.
        max_context: if set, only the last ``max_context`` tokens are fed to
            the model at each step (e.g. the model's max_seq_len).
        lm: model to use; defaults to the notebook-level ``model``. It is
            called as ``lm(tokens)``, and its first output must be logits of
            shape (batch, seq, vocab).
        enc: tokenizer with ``encode``/``decode``; defaults to the
            notebook-level ``tokenizer``.

    Returns:
        ``prompt`` plus the generated text when ``num_return_sequences == 1``
        (the original return type), otherwise a list of such strings.

    Raises:
        ValueError: if the prompt encodes to zero tokens.
    """
    lm = model if lm is None else lm
    enc = tokenizer if enc is None else enc
    device = torch.device(device) if torch.cuda.is_available() else torch.device('cpu')
    prompt_tokens = enc.encode(prompt)
    if not prompt_tokens:
        # torch.tensor([[]]) would be a float tensor the model cannot embed.
        raise ValueError("prompt must encode to at least one token")
    prompt_len = len(prompt_tokens)

    # Generate in eval mode (no dropout) and restore the caller's mode afterwards.
    was_training = lm.training
    lm.eval()
    results = []
    try:
        with torch.no_grad():
            for _ in range(num_return_sequences):
                generated = torch.tensor([prompt_tokens], dtype=torch.long, device=device)
                for _ in range(max_len):
                    model_input = generated if max_context is None else generated[:, -max_context:]
                    next_token_logits = lm(model_input)[0][:, -1, :]
                    if repetition_penalty != 1.0:
                        seen = next_token_logits.gather(-1, generated)
                        seen = torch.where(seen > 0, seen / repetition_penalty, seen * repetition_penalty)
                        next_token_logits = next_token_logits.scatter(-1, generated, seen)
                    if do_sample and temperature > 0:
                        filtered = _filter_logits(next_token_logits / temperature, top_k=top_k, top_p=top_p)
                        next_token = torch.multinomial(F.softmax(filtered, dim=-1), num_samples=1)
                    else:
                        next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
                    generated = torch.cat((generated, next_token), dim=1)
                results.append(prompt + enc.decode(generated[0].tolist()[prompt_len:]))
    finally:
        lm.train(was_training)
    return results[0] if num_return_sequences == 1 else results
        
print(generate(prompt="Help"))

Help me to the char of the character of my parent.

LUCENTIO:
I could not to a slaughter.

BAPTISTA:
What is to be your daughter,--

Prithee, good
