In [2]:
import numpy as np
import torch
import os

from transformers import GPT2TokenizerFast
from datasets import load_dataset

tokenizer = GPT2TokenizerFast.from_pretrained('gpt2', bos_token='<|sos|>', eos_token='<|eos|>', pad_token='<|pad|>')
print(tokenizer.convert_tokens_to_ids('<|sos|>'), tokenizer.convert_tokens_to_ids('<|eos|>'), tokenizer.convert_tokens_to_ids('<|pad|>'))
# sos = 50257, eos = 50258, pad = 50259

SOS_TOKEN = '<|sos|>'
EOS_TOKEN = '<|eos|>'
PAD_TOKEN = '<|pad|>'
SOS_TOKEN_ID = tokenizer.convert_tokens_to_ids('<|sos|>')
EOS_TOKEN_ID = tokenizer.convert_tokens_to_ids('<|eos|>')
PAD_TOKEN_ID = tokenizer.convert_tokens_to_ids('<|pad|>')

  from .autonotebook import tqdm as notebook_tqdm
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


50257 50258 50259


In [15]:
device = [
    torch.device('cuda:0'),
    torch.device('cuda:1'),
    torch.device('cuda:2'),
    torch.device('cuda:3'),
]

In [4]:
def tokenize_sentence(sentence, max_length=None):
    if max_length:
        return tokenizer(sentence, return_tensors='pt', padding='max_length', max_length=max_length, truncation=True).input_ids
    else:
        return tokenizer(sentence, return_tensors='pt').input_ids

In [20]:
class PositionwiseFeedForwardLayer(torch.nn.Module):
    def __init__(self, d_model: int, dropout: float):
        super().__init__()

        self.linear1 = torch.nn.Linear(d_model, 4 * d_model)
        self.linear2 = torch.nn.Linear(4 * d_model, d_model)
        self.relu = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, x):
        x = self.relu(self.linear1(x))
        x = self.dropout(self.linear2(x))

        # x shape == output shape
        return x

In [21]:
class Head(torch.nn.Module):
    def __init__(self, d_model: int, d_head: int, dropout: float):
        super().__init__()

        assert d_model % d_head == 0
        d_tensor = d_model // d_head
        self.d_tensor = d_tensor

        self.key = torch.nn.Linear(d_model, d_tensor)
        self.query = torch.nn.Linear(d_model, d_tensor)
        self.value = torch.nn.Linear(d_model, d_tensor)

        self.dropout = torch.nn.Dropout(dropout)
    
    def forward(self, q, k, v, mask=None):

        # q, k, v = (batch_size, seq_len, d_model)

        q, k = self.query(k), self.key(q)

        # q, k = (batch_size, seq_len, d_tensor)
        # kT = (batch_size, d_tensor, seq_len)

        wei = q @ k.transpose(-2, -1) * (self.d_tensor ** (-0.5)) # q*kT/sqrt(d_k)

        # wei = (batch_size, seq_len, seq_len)

        if mask is not None:
            wei.masked_fill(mask==0, -1e10)
        
        wei = torch.nn.functional.softmax(wei, dim=-1)
        v = self.value(v)

        # wei = (batch_size, seq_len, seq_len)
        # v = (batch_size, seq_len, d_tensor)

        out = wei @ v

        # out = (batch_size, seq_len, d_tensor): d_tensor * n_heads = d_model

        return out

In [22]:
class MultiHeadAttention(torch.nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_head: int, dropout: float, num_gpus: int):
        super().__init__()

        assert d_model % d_head == 0
        assert n_heads % num_gpus == 0
        d_tensor = d_model // d_head
        self.d_tensor = d_tensor

        self.heads = torch.nn.ModuleList([
            Head(d_model=d_model, d_head=d_head, dropout=dropout) for _ in range(n_heads)
        ])
        self.linear = torch.nn.Linear(n_heads * d_tensor, d_model) # n_heads * d_tensor == d_model
        self.dropout = torch.nn.Dropout(dropout)
    
    def forward(self, q, k, v, src_mask=None):
        out = torch.cat([
            head(q, k, v, src_mask) for head in self.heads
        ])

        return out

In [23]:
class LayerNorm(torch.nn.Module):
    def __init__(self, d_model: int, eps=1e-12):
        super().__init__()

        self.gamma = torch.nn.Parameter(torch.ones(d_model))
        self.beta = torch.nn.Parameter(torch.zeros(d_model))
        self.eps = eps
    
    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        var = x. var(-1, unbiased=False, keepdim=True)

        out = (x - mean) * ((var + self.eps) ** (-0.5))
        out = self.gamma * out + self.beta

        return out

In [24]:
class DecoderLayer(torch.nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_head: int, dropout: float, device_num: int):
        super().__init__()

        self.device_num = device_num
        self.attention_layernorm = LayerNorm(d_model)
        self.feedforward_layernorm = LayerNorm(d_model)

        self.self_attention = MultiHeadAttention(d_model=d_model, n_heads=n_heads, d_head=d_head, dropout=dropout, num_gpus=4)
        self.positionwise_feedforward = PositionwiseFeedForwardLayer(d_model=d_model, dropout=dropout)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, trg, trg_mask=None):

        # trg = (batch_size, seq_len, d_model)
        # trg_mask = (batch_size, seq_len)

        trg = trg.to(device[self.device_num])
        trg_mask = trg_mask.to(device[self.device_num]) if trg_mask is not None else None
    
        # self attention with dropout
        _trg = self.dropout(self.self_attention(trg, trg, trg, trg_mask))

        # _trg = (batch_size, seq_len, d_model) == trg
        # add & norm with residual connection
        trg = self.attention_layernorm(trg + _trg)


        # trg = (batch_size, seq_len, d_model)
        # positionwise feedforward layer
        _trg = self.dropout(self.positionwise_feedforward(trg))
        trg = self.feedforward_layernorm(_trg + trg)

        # trg = (batch_size, seq_len, d_model)
        return trg

In [25]:
class Decoder(torch.nn.Module):
    def __init__(self, vocab_size: int, d_model: int, n_layers: int, n_heads: int, max_length: int, dropout: float, num_gpus: int):
        super().__init__()

        # positional encoding
        self.token_embedding = torch.nn.Embedding(vocab_size, d_model)
        self.position_embedding = torch.nn.Embedding(max_length, d_model)

        self.n_layers = n_layers
        self.gpu = n_layers // num_gpus # 3
        self.layers = torch.nn.ModuleList([
            *[DecoderLayer(d_model=d_model, n_heads=n_heads, d_head=d_model//n_heads, dropout=dropout, device_num=0).to(device[0]) for _ in range(self.gpu)],
            *[DecoderLayer(d_model=d_model, n_heads=n_heads, d_head=d_model//n_heads, dropout=dropout, device_num=1).to(device[1]) for _ in range(self.gpu)],
            *[DecoderLayer(d_model=d_model, n_heads=n_heads, d_head=d_model//n_heads, dropout=dropout, device_num=2).to(device[2]) for _ in range(self.gpu)],
            *[DecoderLayer(d_model=d_model, n_heads=n_heads, d_head=d_model//n_heads, dropout=dropout, device_num=3).to(device[3]) for _ in range(self.gpu)],
        ])

        self.fc_out = torch.nn.Linear(d_model, vocab_size)
        self.dropout = torch.nn.Dropout(dropout)
    
    def get_trg_mask(self, trg):
        # trg = (batch_size, seq_len)
        batch_size, seq_len = trg.shape
        trg_mask = torch.tril(torch.ones((seq_len, seq_len))).expand(batch_size, 1, seq_len, seq_len)

        # trg_mask = (batch_size, 1, seq_len, seq_len)
        return trg_mask
    
    def forward(self, trg):
        
        # trg = (batch_size, seq_len)
        # trg_mask = (batch_size, seq_len)

        batch_size, seq_len = trg.shape
        trg_mask = self.get_trg_mask(trg).to(device[0])

        pos = torch.arange(0, seq_len).unsqueeze(0).repeat(batch_size, 1).to(device[0])
        trg = self.dropout((self.token_embedding(trg) + self.position_embedding(pos)))

        # trg = (batch_size, seq_len, d_model)

        # Decoder layers
        for layer in self.layers:
            trg = layer(trg, trg_mask)
        
        # trg = (batch_size, seq_len, d_model)

        output = self.fc_out(trg)

        # output = (batch_size, seq_len, vocab_size)

        return output

In [26]:
class GPTModel(torch.nn.Module):
    def __init__(self, vocab_size: int, d_model: int, n_layers: int, n_heads: int, max_length: int, dropout: float, tokenizer, num_gpus: int=4):
        super().__init__()

        self.tokenizer = tokenizer
        self.decoder = Decoder(vocab_size=vocab_size, d_model=d_model, n_layers=n_layers, n_heads=n_heads, max_length=max_length, dropout=dropout, num_gpus=4)
    

    # def forward(self, trg, trg_mask):

In [28]:
# hyperparameters

# model hyperparameters (from GPT3 small)
n_layers = 12
d_model = 768
n_heads = 12
d_tensor = d_model // n_heads # => 64
d_head = 64
max_length = 1024

vocab_size = tokenizer.vocab_size + 3 # +3 for <sos>, <eos>, <pad>
dropout = 0.1
batch_size = 64
learning_rate = 2e-5
num_epochs = 5


print(f"d_tensor = {d_tensor}")
print(f"d_model = {d_model}")

d_tensor = 64
d_model = 768


In [29]:
model = Decoder(vocab_size=vocab_size, d_model=d_model, n_layers=n_layers, n_heads=n_heads, max_length=max_length, dropout=dropout, num_gpus=len(device))

print(model)
print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):_}")

Decoder(
  (token_embedding): Embedding(50260, 768)
  (position_embedding): Embedding(1024, 768)
  (layers): ModuleList(
    (0): DecoderLayer(
      (attention_layernorm): LayerNorm()
      (feedforward_layernorm): LayerNorm()
      (self_attention): MultiHeadAttention(
        (heads): ModuleList(
          (0): Head(
            (key): Linear(in_features=768, out_features=12, bias=True)
            (query): Linear(in_features=768, out_features=12, bias=True)
            (value): Linear(in_features=768, out_features=12, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): Head(
            (key): Linear(in_features=768, out_features=12, bias=True)
            (query): Linear(in_features=768, out_features=12, bias=True)
            (value): Linear(in_features=768, out_features=12, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (2): Head(
            (key): Linear(in_features=768, out_features=12, bias=True)
 

In [31]:
a = torch.randint(low=0, high=vocab_size, size=(batch_size, max_length)).to(device[0])
b = model(a)
print(b.shape)

TypeError: Decoder.forward() missing 1 required positional argument: 'trg_mask'