In [1]:
import numpy as np
import torch
import os

from transformers import GPT2TokenizerFast
from datasets import load_dataset

# Special tokens used by the seq2seq model. This cell's output shows they are
# appended to the GPT-2 vocabulary as new ids (50257, 50258, 50259), so any
# embedding table must cover those ids (e.g. size it with len(tokenizer)).
SOS_TOKEN = '<|sos|>'
EOS_TOKEN = '<|eos|>'
PAD_TOKEN = '<|pad|>'

tokenizer = GPT2TokenizerFast.from_pretrained('gpt2', bos_token=SOS_TOKEN, eos_token=EOS_TOKEN, pad_token=PAD_TOKEN)

SOS_TOKEN_ID = tokenizer.convert_tokens_to_ids(SOS_TOKEN)
EOS_TOKEN_ID = tokenizer.convert_tokens_to_ids(EOS_TOKEN)
PAD_TOKEN_ID = tokenizer.convert_tokens_to_ids(PAD_TOKEN)

# An unregistered token would silently map to the unknown-token id; fail loudly instead.
_special_token_ids = (SOS_TOKEN_ID, EOS_TOKEN_ID, PAD_TOKEN_ID)
assert tokenizer.unk_token_id not in _special_token_ids, "special token fell back to unk"
assert len(set(_special_token_ids)) == 3, "special tokens must have distinct ids"

print(SOS_TOKEN_ID, EOS_TOKEN_ID, PAD_TOKEN_ID)
# sos = 50257, eos = 50258, pad = 50259

  from .autonotebook import tqdm as notebook_tqdm
Downloading (…)olve/main/vocab.json: 100%|██████████| 1.04M/1.04M [00:00<00:00, 63.5MB/s]
Downloading (…)olve/main/merges.txt: 100%|██████████| 456k/456k [00:00<00:00, 65.9MB/s]
Downloading (…)/main/tokenizer.json: 100%|██████████| 1.36M/1.36M [00:00<00:00, 140MB/s]
Downloading (…)lve/main/config.json: 100%|██████████| 665/665 [00:00<00:00, 1.98MB/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


50257 50258 50259


In [3]:
# Device handles for CUDA ordinals 0-3. Keep the count in sync with `num_gpus`,
# which controls how attention heads are partitioned.
device = [torch.device(f'cuda:{ordinal}') for ordinal in range(4)]

In [None]:
def tokenize_sentence(sentence, max_length=None):
    """Encode text with the module-level ``tokenizer`` and return its token ids.

    Args:
        sentence: text (or texts) to encode.
        max_length: when truthy, the ids are padded to exactly ``max_length``
            (using the tokenizer's pad token) and longer inputs are truncated.
            When ``None`` or 0, the text is encoded without padding or truncation.

    Returns:
        The ``input_ids`` tensor produced with ``return_tensors='pt'``
        (batch dimension first).
    """
    if not max_length:
        return tokenizer(sentence, return_tensors='pt').input_ids

    encoded = tokenizer(
        sentence,
        return_tensors='pt',
        padding='max_length',
        max_length=max_length,
        truncation=True,
    )
    return encoded.input_ids

In [2]:
class PositionwiseFeedForwardLayer(torch.nn.Module):
    """Two-layer feed-forward block applied independently at each position.

    Computes ``dropout(linear2(relu(linear1(x))))``. The hidden layer is four
    times wider than ``d_model``, and the output has the same shape as the input.

    Args:
        d_model: size of the last dimension of the input and output.
        dropout: dropout probability applied to the block's output.
    """

    def __init__(self, d_model: int, dropout: float):
        super().__init__()

        hidden_dim = 4 * d_model
        # Attribute names are kept stable so state_dict keys do not change.
        self.linear1 = torch.nn.Linear(d_model, hidden_dim)
        self.linear2 = torch.nn.Linear(hidden_dim, d_model)
        self.relu = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.relu(self.linear1(x))
        return self.dropout(self.linear2(hidden))

In [5]:
class Head(torch.nn.Module):
    """A single scaled dot-product attention head.

    Projects queries, keys and values from ``d_model`` features down to
    ``d_tensor = d_model // d_head`` features, then returns
    ``dropout(softmax(Q K^T / sqrt(d_tensor))) V``.

    Args:
        d_model: size of the last dimension of the incoming q, k and v.
        d_head: divisor of ``d_model`` that gives the per-head width. Despite
            its name, the ``n_heads * d_tensor == d_model`` comment in
            MultiHeadAttention suggests callers pass the number of heads here
            (TODO confirm).
        dropout: dropout probability applied to the attention weights.
        device_num: index of the device group this head belongs to. It is
            stored for reference only; this module does not move itself.
    """

    def __init__(self, d_model: int, d_head: int, dropout: float, device_num: int):
        super().__init__()

        assert d_model % d_head == 0
        d_tensor = d_model // d_head
        self.d_tensor = d_tensor
        self.device_num = device_num

        self.key = torch.nn.Linear(d_model, d_tensor)
        self.query = torch.nn.Linear(d_model, d_tensor)
        self.value = torch.nn.Linear(d_model, d_tensor)

        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: query input whose last dim is d_model; presumably (batch, q_len, d_model).
            k: key input whose last dim is d_model; presumably (batch, k_len, d_model).
            v: value input with the same sequence length as ``k``.
            mask: optional tensor broadcastable to the (..., q_len, k_len) score
                matrix. Positions where ``mask == 0`` get a score of -1e10
                before softmax, i.e. effectively zero attention weight.

        Returns:
            Tensor of shape (..., q_len, d_tensor).
        """
        # Each projection must be applied to its own input; swapping q and k
        # breaks whenever the query and key sequence lengths differ.
        q = self.query(q)
        k = self.key(k)
        v = self.value(v)

        # q*kT/sqrt(d_k): (..., q_len, k_len)
        scores = q @ k.transpose(-2, -1) * (self.d_tensor ** (-0.5))

        if mask is not None:
            # masked_fill is out-of-place, so the result must be reassigned;
            # otherwise the mask is silently ignored.
            scores = scores.masked_fill(mask == 0, -1e10)

        wei = torch.nn.functional.softmax(scores, dim=-1)
        wei = self.dropout(wei)

        # (..., q_len, k_len) @ (..., k_len, d_tensor) -> (..., q_len, d_tensor)
        return wei @ v

In [None]:
class MultiHeadAttention(torch.nn.Module):
    """Multi-head attention assembled from independent ``Head`` modules.

    Heads are organised into ``num_gpus`` groups of ``n_heads // num_gpus``
    heads each. Every head sees the same q, k, v (and mask). Head outputs are
    concatenated along the feature dimension, projected back to ``d_model``,
    and passed through dropout.

    Args:
        d_model: model width (last dimension of inputs and output).
        n_heads: total number of heads; must be divisible by ``num_gpus``.
        d_head: forwarded to each Head; per-head width is ``d_model // d_head``.
        dropout: dropout probability used by each head and on the output.
        num_gpus: number of head groups. Each group's index is passed to its
            heads as ``device_num``; no device placement happens in this class.
    """

    def __init__(self, d_model: int, n_heads: int, d_head: int, dropout: float, num_gpus: int):
        super().__init__()

        assert d_model % d_head == 0
        assert n_heads % num_gpus == 0
        d_tensor = d_model // d_head
        self.d_tensor = d_tensor

        # Nested ModuleList: outer index = device group, inner = heads in that group.
        self.heads = torch.nn.ModuleList([
            torch.nn.ModuleList([
                Head(d_model=d_model, d_head=d_head, dropout=dropout, device_num=num)
                for _ in range(n_heads // num_gpus)
            ])
            for num in range(num_gpus)
        ])
        self.linear = torch.nn.Linear(n_heads * d_tensor, d_model) # n_heads * d_tensor == d_model
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, q, k, v, src_mask=None):
        """Run every head and merge the results.

        Args:
            q, k, v: passed unchanged to every head; presumably
                (batch, seq_len, d_model) — confirm against callers.
            src_mask: optional mask forwarded to every head.

        Returns:
            Tensor whose last dimension is ``d_model``.
        """
        # self.heads holds ModuleLists (one per group), which are not callable
        # as attention heads, so flatten both levels.
        head_outputs = [
            head(q, k, v, src_mask)
            for group in self.heads
            for head in group
        ]
        # Concatenate on the feature axis: width n_heads * d_tensor, matching self.linear.
        out = torch.cat(head_outputs, dim=-1)
        return self.dropout(self.linear(out))

In [6]:
# Number of groups the attention heads are partitioned into (MultiHeadAttention
# asserts n_heads % num_gpus == 0). Keep in sync with len(device), which lists 4 CUDA devices.
num_gpus = 4