In [None]:
from io import BufferedReader
from zstandard import ZstdDecompressor
from urllib import request
from datetime import datetime
from torch.amp import GradScaler
from torch.utils.data import DataLoader, IterableDataset
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from torch.nn import (
    Linear,
    Embedding,
    ModuleList,
    Module,
    LayerNorm,
    GELU,
    CrossEntropyLoss,
    functional,
    utils,
)
import torch
import re
import pandas
import tarfile
import tiktoken
import html
import itertools
import json

In [None]:
# https://medium.com/@saeed.mehrang/understanding-grouped-query-attention-a-practical-guide-with-pytorch-implementation-9e3f9f26bb79
class GQA(Module):
    """Grouped-query self-attention.

    Queries use ``N_QHEAD`` heads while keys and values use only ``N_KVHEAD``
    heads, so the key/value projections are ``N_QHEAD // N_KVHEAD`` times
    narrower than the query projection. Every head has width
    ``N_EMBD // N_QHEAD``.
    """

    def __init__(self):
        super().__init__()

        self.q_proj = Linear(N_EMBD, N_EMBD)
        self.k_proj = Linear(N_EMBD, N_EMBD * N_KVHEAD // N_QHEAD)
        self.v_proj = Linear(N_EMBD, N_EMBD * N_KVHEAD // N_QHEAD)
        self.o_proj = Linear(N_EMBD, N_EMBD)

    def forward(self, x, mask, is_causal):
        """Attend over ``x``.

        Args:
            x: activations of shape ``(batch, seq, N_EMBD)``.
            mask: boolean padding mask of shape ``(batch, seq)``; ``True``
                marks positions that may be attended to.
            is_causal: when true, a position may only attend to itself and
                earlier positions.

        Returns:
            Tensor of shape ``(batch, seq, N_EMBD)``.
        """
        batch_size, seq_len = x.size(0), x.size(1)
        head_dim = N_EMBD // N_QHEAD

        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        q = self.q_proj(x).view(batch_size, seq_len, N_QHEAD, head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, seq_len, N_KVHEAD, head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_len, N_KVHEAD, head_dim).transpose(1, 2)

        # Key-padding mask, broadcast over heads and query positions.
        mask = mask.view(batch_size, 1, 1, seq_len)

        if is_causal:
            # One (seq, seq) lower-triangular mask on the input's device;
            # broadcasting against the padding mask yields (batch, 1, seq, seq).
            causal = torch.ones(
                seq_len, seq_len, dtype=torch.bool, device=x.device
            ).tril(0)
            mask = mask & causal
        else:
            mask = mask.expand(batch_size, 1, seq_len, seq_len)

        # The functional API applies dropout whenever dropout_p > 0, regardless
        # of module mode, so it must be switched off explicitly outside training.
        x = functional.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=mask,
            dropout_p=DROPOUT if self.training else 0.0,
            enable_gqa=True,
        )
        x = x.transpose(1, 2).contiguous().view(batch_size, seq_len, N_EMBD)
        x = self.o_proj(x)

        return x


class MLP(Module):
    """Feed-forward block: widen to ``4 * N_EMBD``, apply tanh-approximated
    GELU, then project back to ``N_EMBD``."""

    def __init__(self):
        super().__init__()

        hidden = 4 * N_EMBD
        self.c_fc = Linear(N_EMBD, hidden)
        self.gelu = GELU(approximate="tanh")
        self.c_proj = Linear(hidden, N_EMBD)

    def forward(self, x):
        """Return ``c_proj(gelu(c_fc(x)))``."""
        return self.c_proj(self.gelu(self.c_fc(x)))


class Decoder(Module):
    """Pre-norm transformer layer: LayerNorm -> attention and
    LayerNorm -> MLP, each wrapped in a residual connection."""

    def __init__(self):
        super().__init__()

        self.ln_1 = LayerNorm(N_EMBD)
        self.attn = GQA()
        self.ln_2 = LayerNorm(N_EMBD)
        self.mlp = MLP()

    def forward(self, x, mask, is_causal):
        """Apply the attention sub-layer, then the MLP sub-layer.

        ``mask`` and ``is_causal`` are passed through to the attention module
        unchanged.
        """
        attended = self.attn(self.ln_1(x), mask, is_causal)
        x = x + attended

        transformed = self.mlp(self.ln_2(x))
        return x + transformed


class Model(Module):
    """Decoder-only language model: token + learned position embeddings,
    ``N_LAYER`` decoder layers, a final LayerNorm and a linear head that
    produces one logit per vocabulary entry."""

    def __init__(self, vocab_size):
        """
        Args:
            vocab_size: number of rows in the token embedding and number of
                output logits per position.
        """
        super().__init__()

        # Registered as a buffer so it follows the module across devices
        # (``.to(...)``) instead of being pinned to CUDA at construction time.
        # ``persistent=False`` keeps it out of ``state_dict``.
        self.register_buffer("pos", torch.arange(CONTEXT_SIZE), persistent=False)
        self.wte = Embedding(vocab_size, N_EMBD)
        self.wpe = Embedding(CONTEXT_SIZE, N_EMBD)
        self.h = ModuleList(Decoder() for _ in range(N_LAYER))
        self.ln_f = LayerNorm(N_EMBD)
        self.lm_head = Linear(N_EMBD, vocab_size)

    def forward(self, x, mask, is_causal):
        """Compute next-token logits.

        Args:
            x: integer token ids of shape ``(batch, seq)`` with
                ``seq <= CONTEXT_SIZE``.
            mask: boolean padding mask, forwarded to every decoder layer.
            is_causal: forwarded to every decoder layer.

        Returns:
            Logits of shape ``(batch, seq, vocab_size)``.
        """
        # Only the first ``seq`` position embeddings are needed.
        x = self.wte(x) + self.wpe(self.pos[: x.size(1)])

        for decoder in self.h:
            x = decoder(x, mask, is_causal)

        x = self.ln_f(x)
        x = self.lm_head(x)

        return x

In [None]:
def get_parameters(model):
    """Split the model's parameters into two optimizer groups.

    Parameters whose name contains ``"wpe.weight"``, ``"bias"`` or ``"ln"``
    go into a group with no weight decay; everything else is decayed with
    ``WEIGHT_DECAY``.

    Returns:
        ``[decayed_group, undecayed_group]`` in the param-group dict format
        accepted by torch optimizers.
    """
    no_decay_markers = ("wpe.weight", "bias", "ln")
    decayed, undecayed = [], []

    for name, parameter in model.named_parameters():
        is_exempt = any(marker in name for marker in no_decay_markers)
        (undecayed if is_exempt else decayed).append(parameter)

    return [
        {"params": decayed, "weight_decay": WEIGHT_DECAY},
        {"params": undecayed, "weight_decay": 0.0},
    ]

In [None]:
def predict(text, model, tokenizer):
    """Extend ``text`` one token at a time using top-k sampling.

    Generation stops when the end-of-text token is sampled, when the prompt
    is empty, or when the text no longer fits into a single context window.

    Args:
        text: prompt string.
        model: the language model; called as ``model(tokens, mask, is_causal)``.
        tokenizer: object providing ``encode``, ``decode`` and ``eot_token``.

    Returns:
        The prompt followed by the generated continuation.
    """
    device = next(model.parameters()).device
    was_training = model.training
    model.eval()

    try:
        while True:
            chunks, length = tokenize(text, tokenizer, False)

            # Zero chunks: empty prompt. More than one: context window exceeded.
            if len(chunks) != 1:
                break

            tokens_and_mask = torch.tensor(chunks, device=device)
            tokens = tokens_and_mask[:, 0, :]
            mask = tokens_and_mask[:, 1, :].to(torch.bool)

            with torch.autocast(device.type), torch.no_grad():
                # Training runs the model with causal attention, so inference
                # must too: otherwise earlier positions attend to later ones
                # and deeper layers see activations unlike those seen in
                # training.
                logits = model(tokens, mask, True)[0, length - 1, :]

            # Sample in float32; autocast may have produced half-precision logits.
            probabilities = torch.softmax(logits.float(), 0)
            probabilities, indices = torch.topk(probabilities, TOPK)
            index = torch.multinomial(probabilities, 1)
            token = indices[index]

            if token.item() == tokenizer.eot_token:
                break

            text += tokenizer.decode(token.tolist())
    finally:
        # Restore whichever mode the caller had, even if generation failed.
        model.train(was_training)

    return text

In [None]:
def get_loss(batch, loss_function, model=None):
    """Next-token prediction loss for one batch.

    Args:
        batch: dict with ``"tokens"`` (integer ids, ``(batch, seq)``) and
            ``"mask"`` (boolean padding mask of the same shape).
        loss_function: callable taking ``(logits, targets)``.
        model: network to evaluate. When omitted, the notebook-level ``model``
            variable is used, which is what this function always did before
            the parameter existed.

    Returns:
        The scalar loss tensor returned by ``loss_function``.
    """
    if model is None:
        model = globals()["model"]

    tokens = batch["tokens"]

    with torch.autocast(tokens.device.type):
        logits = model(tokens, batch["mask"], True)

    # Position t predicts token t + 1: drop the last prediction and the first
    # target. The logits are cast to float32 because this code runs outside
    # the autocast region and they may be half precision.
    logits = logits[:, :-1, :].float()
    targets = tokens[:, 1:]

    return loss_function(
        logits.reshape(-1, logits.size(-1)),
        targets.reshape(-1),
    )


def evaluate_model(model, data_loader, loss_function, n_samples=256):
    """Average loss over roughly the first ``n_samples`` items of ``data_loader``.

    At least one batch is always requested, so a ``BATCH_SIZE`` larger than
    ``n_samples`` no longer results in a division by zero.

    Returns:
        The mean batch loss as a Python float, or ``nan`` if the loader
        produced no batches.
    """
    n_batches = max(1, n_samples // BATCH_SIZE)
    was_training = model.training
    model.eval()

    try:
        with torch.no_grad():
            losses = [
                get_loss(batch, loss_function, model).item()
                for batch in itertools.islice(data_loader, n_batches)
            ]
    finally:
        # Restore the caller's mode even if evaluation raised.
        model.train(was_training)

    if not losses:
        return float("nan")

    return sum(losses) / len(losses)


def train_model(model, train, valid, loss_function, optimizer, scheduler, scaler):
    """Run one pass over ``train`` with gradient accumulation and loss scaling.

    Gradients are accumulated over ``ACCUMULATION_STEPS`` batches; then they
    are unscaled, clipped to ``MAX_NORM``, and the optimizer and scheduler are
    stepped. Every 100 batches a progress line with train and validation loss
    is printed.
    """

    def print_information(step):
        progress = step * BATCH_SIZE / DATASET_SIZE * 100
        time = (datetime.now() - time_start).total_seconds() / 60
        train_loss = evaluate_model(model, train, loss_function)
        valid_loss = evaluate_model(model, valid, loss_function)
        print(f"{progress:3.0f}% | {time:4.0f} | {train_loss:5.2f} | {valid_loss:5.2f}")

    print("prog | time | train | valid")
    time_start = datetime.now()

    for step, batch in enumerate(train, 1):
        # Divide so the accumulated gradient is the mean over the micro-batches.
        loss = get_loss(batch, loss_function, model) / ACCUMULATION_STEPS
        scaler.scale(loss).backward()

        if step % 100 == 0:
            print_information(step)

        if step % ACCUMULATION_STEPS == 0:
            # Unscale before clipping so MAX_NORM applies to the true gradients.
            scaler.unscale_(optimizer)
            utils.clip_grad_norm_(model.parameters(), MAX_NORM)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad()

In [None]:
def download_dataset():
    """Fetch ``pol_0616-1119_labeled.tar.zst`` from Zenodo record 3606810 and
    save it in the working directory as ``dataset_train.tar.zst``."""
    source_url = "https://zenodo.org/records/3606810/files/pol_0616-1119_labeled.tar.zst"
    destination = "dataset_train.tar.zst"

    request.urlretrieve(source_url, destination)


class Dataset(IterableDataset):
    """Streams padded token windows from a zstd-compressed tar archive.

    The first member of the archive is read line by line; each line is parsed
    as JSON with a ``"posts"`` list, and the ``"com"`` field of each post is
    cleaned, tokenized and yielded as fixed-size windows.
    """

    def __init__(self, tokenizer, dataset, device="cuda"):
        """
        Args:
            tokenizer: tokenizer handed to ``tokenize``.
            dataset: path of the ``.tar.zst`` archive.
            device: device on which the yielded tensors are created.
        """
        self.tokenizer = tokenizer
        self.dataset = dataset
        self.device = device

    def __iter__(self):
        """Yield ``{"tokens", "mask"}`` dicts, at most ``DATASET_SIZE`` of them."""
        emitted = 0
        for post in self.read_posts():
            if "com" not in post:
                continue

            comment = self.parse_comment(post["com"])
            if not comment:
                continue

            try:
                chunks = tokenize(comment, self.tokenizer, True)[0]
            except ValueError:
                # NOTE(review): tiktoken's ``encode`` is expected to raise
                # ValueError for text that spells out a special token such as
                # "<|endoftext|>"; skip such a comment instead of aborting the
                # whole run.
                continue

            for tokens_and_mask in chunks:
                tokens_and_mask = torch.tensor(tokens_and_mask, device=self.device)
                tokens = tokens_and_mask[0, :]
                mask = tokens_and_mask[1, :].to(torch.bool)

                yield {"tokens": tokens, "mask": mask}

                emitted += 1
                if DATASET_SIZE <= emitted:
                    return

        print("Ran out of dataset items early, LR scheduling might be unexpected")

    def read_posts(self):
        """Yield every post dict found in the first member of the archive."""
        with open(self.dataset, "rb") as dataset:
            with ZstdDecompressor().stream_reader(dataset) as stream:
                # "r|" reads the tar as a forward-only stream, which is all a
                # decompression stream supports.
                with tarfile.open(fileobj=BufferedReader(stream), mode="r|") as file:
                    for line in file.extractfile(file.firstmember):
                        if not line.strip():
                            continue  # tolerate blank lines between records
                        for post in json.loads(line.decode("utf-8"))["posts"]:
                            yield post

    def parse_comment(self, text):
        """Convert a post's HTML comment to plain text.

        ``<br>`` becomes a newline, all other tags are removed, and only then
        are HTML entities decoded. Decoding first would turn user-typed
        ``&lt;`` / ``&gt;`` into angle brackets that the tag-stripping step
        mistakes for markup, deleting everything between them.
        """
        text = re.sub(r"<br\s*/?>", "\n", text)

        tag = re.compile(r"<[^<>]*>")
        while tag.search(text):
            text = tag.sub("", text)

        return html.unescape(text)

In [None]:
def get_scheduler(data_loader, optimizer):
    """Build a linear-warmup followed by cosine-decay learning-rate schedule.

    The schedule length is derived from ``DATASET_SIZE``, ``BATCH_SIZE`` and
    ``ACCUMULATION_STEPS`` (one scheduler step per optimizer step); ``WARMUP``
    is the fraction of those steps spent warming up.

    Args:
        data_loader: unused; kept so existing callers keep working.
        optimizer: optimizer whose learning rate is scheduled.

    Returns:
        A ``SequentialLR`` chaining ``LinearLR`` and ``CosineAnnealingLR``.
    """
    total_steps = DATASET_SIZE // (BATCH_SIZE * ACCUMULATION_STEPS)

    # At least one warmup step is required. With zero, SequentialLR's
    # milestone is never reached by step(), so the cosine phase would start
    # from the warmup's initial value (lr * start_factor) rather than from
    # the configured learning rate.
    warmup_steps = max(1, int(total_steps * WARMUP))

    # T_max must be positive; CosineAnnealingLR divides by it.
    cosine_steps = max(1, total_steps - warmup_steps)

    # NOTE(review): start_factor is a multiplier on the base learning rate,
    # not a learning rate; LEARNING_RATE is reused here as a small factor.
    linear = LinearLR(optimizer, start_factor=LEARNING_RATE, total_iters=warmup_steps)
    cosine = CosineAnnealingLR(optimizer, T_max=cosine_steps)

    return SequentialLR(
        optimizer=optimizer,
        schedulers=[linear, cosine],
        milestones=[warmup_steps],
    )

In [None]:
def tokenize(text, tokenizer, add_eot, context_size=None, pad_token=43000):
    """Encode ``text`` and cut it into fixed-size, right-padded windows.

    Args:
        text: string to encode.
        tokenizer: object with ``encode(text)`` and an ``eot_token`` attribute.
        add_eot: append the end-of-text token after the encoded text.
        context_size: window length; defaults to the notebook's
            ``CONTEXT_SIZE``.
        pad_token: id used to fill the unused tail of the last window.

    Returns:
        ``(windows, length)`` where ``windows`` is a list of
        ``(token_ids, mask)`` pairs, each of length ``context_size`` with mask
        value 1 for real tokens and 0 for padding, and ``length`` is the
        number of real tokens (including the end-of-text token, if added).
        Empty input yields ``([], 0)``.
    """
    if context_size is None:
        context_size = CONTEXT_SIZE

    tokens = list(tokenizer.encode(text))
    if add_eot:
        tokens.append(tokenizer.eot_token)

    length = len(tokens)
    padded = []

    for start in range(0, length, context_size):
        window = tokens[start : start + context_size]
        remaining = context_size - len(window)
        mask = [1] * len(window) + [0] * remaining
        padded.append((window + [pad_token] * remaining, mask))

    return padded, length

In [None]:
# Model architecture
CONTEXT_SIZE = 128  # tokens per window
N_EMBD = 768  # embedding width
N_LAYER = 12  # decoder layers
N_QHEAD = 12  # query heads
N_KVHEAD = 3  # key/value heads shared between the query heads
DROPOUT = 0.1  # attention dropout probability

# Data and batching
DATASET_SIZE = 8192  # maximum number of windows a dataset yields per pass
BATCH_SIZE = 16
ACCUMULATION_STEPS = 8  # micro-batches per optimizer step

# Optimizer and schedule
LEARNING_RATE = 5e-4
BETAS = (0.9, 0.999)
EPS = 1e-8
WEIGHT_DECAY = 0.01
MAX_NORM = 1.0  # gradient-clipping threshold
WARMUP = 0.01  # fraction of optimizer steps used for warmup

# Sampling
TOPK = 50  # candidates kept when sampling the next token

In [None]:
# Uncomment and run once to fetch the training archive into dataset_train.tar.zst.
# download_dataset()

In [None]:
# Tokenizer and data pipelines
tokenizer = tiktoken.encoding_for_model("gpt2")
train = DataLoader(
    Dataset(tokenizer, "dataset_train.tar.zst"),
    batch_size=BATCH_SIZE,
)
valid = DataLoader(
    Dataset(tokenizer, "dataset_valid.tar.zst"),
    batch_size=BATCH_SIZE,
)

# Model, optimizer and schedule
model = Model(tokenizer.n_vocab).to("cuda")
optimizer = AdamW(
    params=get_parameters(model),
    lr=LEARNING_RATE,
    betas=BETAS,
    eps=EPS,
    fused=True,
)
scheduler = get_scheduler(train, optimizer)

# Loss (targets equal to 43000, the id tokenize pads with, are ignored) and loss scaler
loss_function = CrossEntropyLoss(ignore_index=43000)
scaler = GradScaler()

In [None]:
# Run training; a progress line is printed every 100 batches.
train_model(
    model=model,
    train=train,
    valid=valid,
    loss_function=loss_function,
    optimizer=optimizer,
    scheduler=scheduler,
    scaler=scaler,
)

In [None]:
# Sample a continuation of the prompt "There" from the trained model.
prediction = predict(text="There", model=model, tokenizer=tokenizer)
print(prediction)