In [None]:
import html
import itertools
import math
import os
import re
import tarfile
from urllib import request

import pandas
import tiktoken
import torch
from torch.amp import GradScaler
from torch.nn import (
    CrossEntropyLoss,
    Embedding,
    LayerNorm,
    Linear,
    Module,
    ModuleList,
    ReLU,
)
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from torch.utils.data import DataLoader

In [None]:
class MSA(Module):
    """Causal multi-head self-attention with an optional key-padding mask.

    Args:
        n_ctx: maximum sequence length; sizes the causal-mask buffer.
        n_embd: model width ``C``; must be divisible by ``n_head``.
        n_head: number of attention heads.
    """

    def __init__(self, n_ctx, n_embd, n_head):
        super().__init__()

        if n_embd % n_head:
            raise ValueError(f"n_embd ({n_embd}) must be divisible by n_head ({n_head})")

        self.n_embd = n_embd
        self.n_head = n_head
        # Lower-triangular causal mask shaped (1, 1, n_ctx, n_ctx) so it broadcasts
        # over (batch, head). A non-persistent buffer follows ``module.to(device)``
        # with the parameters, yet stays out of ``state_dict`` and out of
        # ``named_parameters`` (so it never reaches the optimizer).
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx),
            persistent=False,
        )
        self.c_attn = Linear(n_embd, 3 * n_embd)
        self.c_proj = Linear(n_embd, n_embd)

    def forward(self, x, mask=None):
        """Attend over ``x``.

        Args:
            x: (B, T, C) activations with T <= n_ctx.
            mask: optional (B, T) tensor; key positions where it is 0 are never
                attended to. ``None`` applies only the causal mask.

        Returns:
            (B, T, C) tensor.
        """
        B, T, C = x.size()
        h = self.n_head
        C_h = C // h

        q, k, v = self.c_attn(x).split(C, -1)

        # (B, T, C) -> (B, h, T, C_h)
        k = k.view(B, T, h, C_h).transpose(1, 2)
        q = q.view(B, T, h, C_h).transpose(1, 2)
        v = v.view(B, T, h, C_h).transpose(1, 2)

        att = (q @ k.transpose(2, 3)) / math.sqrt(C_h)
        # Slice the causal mask to the actual length so shorter inputs broadcast.
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
        if mask is not None:
            # NOTE(review): a query row whose every visible key is masked softmaxes
            # to NaN; right-padded inputs with a real first token avoid this.
            att = att.masked_fill(mask.view(B, 1, 1, T) == 0, float("-inf"))
        att = torch.softmax(att, -1)

        x = att @ v
        x = x.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(x)


class MLP(Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear with a 4x hidden width."""

    def __init__(self, n_embd):
        super().__init__()

        hidden = 4 * n_embd
        self.c_fc = Linear(n_embd, hidden)
        self.relu = ReLU()
        self.c_proj = Linear(hidden, n_embd)

    def forward(self, x):
        # Expand to the hidden width, apply the non-linearity, project back.
        return self.c_proj(self.relu(self.c_fc(x)))


class Decoder(Module):
    """Pre-LayerNorm transformer block: a residual attention sub-layer, then a residual MLP."""

    def __init__(self, n_ctx, n_embd, n_head):
        super().__init__()

        # Registration order is kept as-is: it fixes parameter order and the
        # sequence of random initializations.
        self.ln_1 = LayerNorm(n_embd)
        self.attn = MSA(n_ctx, n_embd, n_head)
        self.ln_2 = LayerNorm(n_embd)
        self.mlp = MLP(n_embd)

    def forward(self, x, mask):
        attended = x + self.attn(self.ln_1(x), mask)
        return attended + self.mlp(self.ln_2(attended))


class Model(Module):
    """Decoder-only transformer language model.

    Token embeddings plus learned position embeddings, ``n_layer`` ``Decoder``
    blocks, a final LayerNorm and a linear head producing vocabulary logits.
    """

    def __init__(self, n_ctx=1024, n_embd=768, n_head=12, n_layer=12, vocab_size=50257):
        super().__init__()

        # Position ids 0..n_ctx-1 as a non-persistent buffer: ``.to(device)`` moves
        # it with the model, and it stays out of state_dict like the old attribute.
        self.register_buffer("pos", torch.arange(n_ctx), persistent=False)
        self.wte = Embedding(vocab_size, n_embd)
        self.wpe = Embedding(n_ctx, n_embd)
        self.h = ModuleList(Decoder(n_ctx, n_embd, n_head) for _ in range(n_layer))
        self.ln_f = LayerNorm(n_embd)
        self.lm_head = Linear(n_embd, vocab_size)

    def forward(self, x, mask):
        """Map (B, T) token ids (T <= n_ctx) and a (B, T) mask to (B, T, vocab_size) logits."""
        T = x.size(1)
        # Only the first T positions are needed; this allows inputs shorter than n_ctx.
        x = self.wte(x) + self.wpe(self.pos[:T])

        for decoder in self.h:
            x = decoder(x, mask)

        x = self.ln_f(x)
        return self.lm_head(x)

In [None]:
def get_parameters(model, weight_decay=None, no_weight_decay=None):
    """Split ``model``'s parameters into two AdamW parameter groups.

    Args:
        model: module whose ``named_parameters()`` are grouped.
        weight_decay: decay for the first group; defaults to ``WEIGHT_DECAY``.
        no_weight_decay: substrings; a parameter whose name contains any of them
            goes to the second (zero-decay) group. Defaults to ``NO_WEIGHT_DECAY``.

    Returns:
        ``[{"params": decayed, "weight_decay": weight_decay},
           {"params": exempt, "weight_decay": 0.0}]``
    """
    if weight_decay is None:
        weight_decay = WEIGHT_DECAY
    if no_weight_decay is None:
        no_weight_decay = NO_WEIGHT_DECAY

    decayed, exempt = [], []
    for name, parameter in model.named_parameters():
        if any(pattern in name for pattern in no_weight_decay):
            exempt.append(parameter)
        else:
            decayed.append(parameter)

    return [
        {"params": decayed, "weight_decay": weight_decay},
        {"params": exempt, "weight_decay": 0.0},
    ]

In [None]:
def predict(text, model, tokenizer, top_k=50, max_new_tokens=None):
    """Continue ``text`` by top-k sampling until end-of-text or the context is full.

    Args:
        text: prompt string.
        model: network called as ``model(tokens, mask)`` returning per-position logits.
        tokenizer: tiktoken encoding (``encode``, ``decode``, ``eot_token``).
        top_k: number of most likely tokens to sample from.
        max_new_tokens: optional cap on generated tokens; ``None`` means no cap.

    Returns:
        The decoded prompt token ids followed by the sampled continuation.
    """
    # Track token ids instead of re-encoding a growing string each step: decoding
    # one token at a time can split a multi-byte UTF-8 character into replacement
    # characters, and re-encoding text need not reproduce the ids that were sampled.
    # disallowed_special=() lets a literal "<|endoftext|>" in the prompt encode as
    # plain text instead of raising.
    token_ids = tokenizer.encode(text, disallowed_special=())
    generated = 0

    # Stop once the sequence no longer fits in a single context window.
    while 0 < len(token_ids) <= CONTEXT_LENGTH:
        if max_new_tokens is not None and generated >= max_new_tokens:
            break

        length = len(token_ids)
        padding = CONTEXT_LENGTH - length
        tokens = torch.tensor([token_ids + [PADDING_TOKEN] * padding], device="cuda")
        mask = torch.tensor([[1] * length + [0] * padding], device="cuda")

        with torch.autocast("cuda"), torch.no_grad():
            # The input is right-padded, so the next-token distribution lives at the
            # last real position; index -1 would be a padding slot.
            logits = model(tokens, mask)[0, length - 1, :]

        probabilities = torch.softmax(logits.float(), 0)
        probabilities, indices = torch.topk(
            probabilities, min(top_k, probabilities.size(0))
        )
        token = indices[torch.multinomial(probabilities, 1)].item()

        if token == tokenizer.eot_token:
            break

        token_ids.append(token)
        generated += 1

    return tokenizer.decode(token_ids)

In [None]:
def get_loss(batch, loss_function, calculate_gradient, model=None):
    """Next-token loss for one collated batch.

    Args:
        batch: dict with "tokens" and "mask" as produced by ``collate``; assumed to
            be (batch, seq) integer tensors.
        loss_function: loss called as ``loss_function(logits_2d, targets_1d)``. If
            it has an ``ignore_index`` attribute, every target at a padded
            position (mask == 0) is set to that value.
        calculate_gradient: build an autograd graph when True; no_grad otherwise.
        model: network to evaluate. Defaults to the notebook-global ``model`` so
            existing three-argument calls keep working.

    Returns:
        Scalar loss tensor.
    """
    if model is None:
        model = globals()["model"]

    tokens = batch["tokens"]
    # Position t predicts token t + 1.
    references = tokens[:, 1:]
    ignore_index = getattr(loss_function, "ignore_index", None)
    if ignore_index is not None:
        # Exclude padding based on the mask, not only on the pad id, so genuine
        # occurrences of the pad id can stay in the loss when ignore_index is set
        # to a value outside the vocabulary.
        references = references.masked_fill(batch["mask"][:, 1:] == 0, ignore_index)

    with torch.autocast("cuda"), torch.set_grad_enabled(calculate_gradient):
        predictions = model(tokens, batch["mask"])[:, :-1, :]
        # Computing the loss inside autocast lets cross_entropy run in float32
        # instead of on half-precision logits.
        return loss_function(
            predictions.reshape(-1, predictions.size(-1)), references.reshape(-1)
        )


def evaluate_model(model, data_loader, loss_function):
    """Mean per-batch loss over the first EVALUATE_ITEMS // BATCH_SIZE batches.

    Runs without gradients in eval mode and restores the model's previous
    train/eval mode afterwards, even if evaluation raises. Returns NaN when the
    loader yields no batches.
    """
    was_training = model.training
    model.eval()
    try:
        batches = itertools.islice(data_loader, EVALUATE_ITEMS // BATCH_SIZE)
        losses = [get_loss(batch, loss_function, False, model) for batch in batches]
    finally:
        model.train(was_training)

    if not losses:
        return float("nan")
    return (sum(losses) / len(losses)).item()


def train_model(model, train, valid, loss_function, optimizer, scheduler, scaler):
    """One epoch of mixed-precision training with gradient accumulation.

    Performs an optimizer step every ACCUMULATION_STEPS batches, and once more at
    the end of the epoch for a trailing partial window. Prints train/valid loss
    every LOGGING_STEPS batches.
    """
    steps = len(train)

    def print_information(step):
        train_loss = evaluate_model(model, train, loss_function)
        valid_loss = evaluate_model(model, valid, loss_function)

        a = f"Progress:   {step}/{steps}"
        b = f"Train loss: {train_loss}"
        c = f"Valid loss: {valid_loss}"

        print(a, b, c, sep="\n")

    # Count batches from 1 so the modulo checks and the progress line agree.
    for step, batch in enumerate(train, start=1):
        loss = get_loss(batch, loss_function, True, model) / ACCUMULATION_STEPS
        scaler.scale(loss).backward()

        # A trailing partial window is stepped too rather than leaving its gradients
        # behind; its loss is still divided by ACCUMULATION_STEPS, so that final
        # update is proportionally smaller.
        if step % ACCUMULATION_STEPS == 0 or step == steps:
            scaler.step(optimizer)
            scale_before = scaler.get_scale()
            scaler.update()
            # https://discuss.pytorch.org/t/userwarning-detected-call-of-lr-scheduler-step-before-optimizer-step
            # GradScaler lowers the scale when it skipped optimizer.step() because of
            # inf/NaN gradients; only advance the schedule when the optimizer stepped.
            if scale_before <= scaler.get_scale():
                scheduler.step()
            optimizer.zero_grad()

        if step % LOGGING_STEPS == 0:
            print_information(step)

In [None]:
def download_files(
    url="https://zenodo.org/records/3606810/files/pol_0616-1119_labeled.tar.zst",
    archive="pol_0616-1119_labeled.tar.zst",
):
    """Download the zstd-compressed tar archive (unless present) and extract it here.

    Args:
        url: where to fetch the archive from.
        archive: local file name for the archive.

    Raises:
        subprocess.CalledProcessError: if the ``tar`` fallback fails to extract.
    """
    import os
    import subprocess

    if not os.path.exists(archive):
        # Download under a temporary name so an interrupted transfer is never
        # mistaken for a complete archive on the next run.
        partial = archive + ".part"
        request.urlretrieve(url, partial)
        os.replace(partial, archive)

    try:
        with tarfile.open(archive) as file:
            # The "data" filter (Python 3.12+ and security backports) refuses members
            # that would land outside the destination, links pointing outside it,
            # and device files.
            if hasattr(tarfile, "data_filter"):
                file.extractall(filter="data")
            else:
                file.extractall()
    except (tarfile.ReadError, tarfile.CompressionError):
        # tarfile reads zstd only from Python 3.14 on; fall back to the system tar.
        subprocess.run(["tar", "-xf", archive], check=True)


def parse_post(post):
    """Convert an HTML-formatted post body into plain text.

    ``<br>`` becomes a newline, all other tags are removed, and HTML entities are
    decoded last. Unescaping last matters: text such as ``a &lt; b &gt; c`` stays
    literal instead of being turned into ``< b >`` and then stripped as a tag.
    """
    post = re.sub(r"<br\s*/?>", "\n", post, flags=re.IGNORECASE)

    tag = re.compile(r"<[^<>]*>")
    while tag.search(post):
        post = tag.sub("", post)

    return html.unescape(post)


def collate(batch, device="cuda"):
    """Stack ``(tokens, mask)`` pairs, as returned by ``tokenize``, into a batch dict.

    Args:
        batch: sequence of ``(tokens, mask)`` pairs of equal-length int lists.
        device: where the tensors are allocated. The default is "cuda"; pass "cpu"
            (e.g. via ``functools.partial``) when CUDA tensors should not be
            created in the loader, such as with DataLoader worker processes.

    Returns:
        ``{"tokens": LongTensor, "mask": LongTensor}``, each (len(batch), seq).
    """
    tokens = torch.tensor([item[0] for item in batch], device=device)
    mask = torch.tensor([item[1] for item in batch], device=device)

    return {"tokens": tokens, "mask": mask}


def get_data_loaders(tokenizer, path="pol_062016-112019_labeled.ndjson"):
    """Build the training and validation DataLoaders from the ndjson dump.

    Reads the first DATASET_ITEMS rows of ``path``. It flattens each row's "posts"
    list, keeps the "com" field of posts that have one, and strips markup with
    ``parse_post``. Empty comments are dropped; the rest are tokenized (with an
    end-of-text token) into fixed-length chunks. The last EVALUATE_ITEMS chunks
    form the validation set, and the remaining chunks form the shuffled training set.

    Raises:
        ValueError: if there are not more than EVALUATE_ITEMS chunks, which would
            leave the training set empty.
    """
    rows = pandas.read_json(path, lines=True, nrows=DATASET_ITEMS)

    comments = (
        parse_post(post["com"])
        for thread_posts in rows["posts"]
        for post in thread_posts
        if "com" in post
    )
    chunks = [
        chunk
        for comment in comments
        if comment
        for chunk in tokenize(comment, tokenizer, True)
    ]

    # Split at an explicit index: chunks[:-0] would be empty if EVALUATE_ITEMS were 0.
    split = len(chunks) - EVALUATE_ITEMS
    if split <= 0:
        raise ValueError(
            f"Need more than {EVALUATE_ITEMS} token chunks to build a training set, "
            f"got {len(chunks)}; increase DATASET_ITEMS."
        )

    train = DataLoader(
        chunks[:split], batch_size=BATCH_SIZE, collate_fn=collate, shuffle=True
    )
    valid = DataLoader(
        chunks[split:], batch_size=BATCH_SIZE, collate_fn=collate, shuffle=False
    )

    return train, valid

In [None]:
def get_scheduler(data_loader, optimizer):
    """Linear warmup followed by cosine decay, measured in optimizer steps.

    ``scheduler.step()`` is meant to be called once per optimizer update, i.e.
    once every ACCUMULATION_STEPS batches. The cosine horizon is therefore based on
    the number of optimizer updates in one pass over ``data_loader``, not on
    ``len(data_loader)``.
    """
    optimizer_steps = math.ceil(len(data_loader) / ACCUMULATION_STEPS)

    # start_factor multiplies the optimizer's base LR. Passing LEARNING_RATE here
    # starts warmup at base_lr * LEARNING_RATE and ramps linearly to base_lr.
    linear = LinearLR(optimizer, start_factor=LEARNING_RATE, total_iters=WARMUP_STEPS)
    # CosineAnnealingLR divides by T_max, so keep it positive even when warmup
    # covers the whole run.
    cosine = CosineAnnealingLR(
        optimizer, T_max=max(1, optimizer_steps - WARMUP_STEPS)
    )

    return SequentialLR(
        optimizer=optimizer,
        schedulers=[linear, cosine],
        milestones=[WARMUP_STEPS],
    )

In [None]:
def tokenize(text, tokenizer, add_eot, context_length=None, padding_token=None):
    """Encode ``text`` into fixed-length, right-padded chunks.

    Args:
        text: string to encode.
        tokenizer: tiktoken-style encoder with ``encode`` and ``eot_token``.
        add_eot: append the end-of-text token before chunking.
        context_length: chunk length; defaults to ``CONTEXT_LENGTH``.
        padding_token: id used to fill the last chunk; defaults to ``PADDING_TOKEN``.

    Returns:
        List of ``(tokens, mask)`` pairs. Each is a list of ``context_length`` ints,
        and mask is 1 for real tokens and 0 for padding. Empty input without EOT
        yields ``[]``.
    """
    if context_length is None:
        context_length = CONTEXT_LENGTH
    if padding_token is None:
        padding_token = PADDING_TOKEN

    # disallowed_special=() encodes text such as "<|endoftext|>" as ordinary
    # characters; tiktoken's default raises ValueError on it, and scraped posts may
    # contain that literal string.
    tokens = tokenizer.encode(text, disallowed_special=())
    if add_eot:
        tokens = tokens + [tokenizer.eot_token]

    chunks = []
    for start in range(0, len(tokens), context_length):
        chunk = tokens[start : start + context_length]
        padding = context_length - len(chunk)
        chunks.append(
            (chunk + [padding_token] * padding, [1] * len(chunk) + [0] * padding)
        )

    return chunks

In [None]:
ACCUMULATION_STEPS = 8  # micro-batches accumulated per optimizer step
BATCH_SIZE = 4  # sequences per micro-batch
CONTEXT_LENGTH = 512  # tokens per sequence; also the model's n_ctx
DATASET_ITEMS = 4000  # rows read from the ndjson dump (each row holds a list of posts)
EVALUATE_ITEMS = 64  # held-out sequences; also sequences scored per evaluation
LEARNING_RATE = 6e-4
LOGGING_STEPS = 64  # micro-batches between train/valid loss reports
NO_WEIGHT_DECAY = ["wpe.weight", "bias", "ln"]  # name substrings exempt from decay
# NOTE(review): 43000 is an ordinary id inside GPT-2's 50257-token vocabulary, so
# any genuine occurrence of it is also excluded from the loss via ignore_index.
PADDING_TOKEN = 43000
SEED = 0  # seeds torch's RNG: weight init, DataLoader shuffling and sampling
WARMUP_STEPS = 2000  # optimizer steps of linear LR warmup
WEIGHT_DECAY = 0.1

torch.manual_seed(SEED)

In [None]:
tokenizer = tiktoken.encoding_for_model("gpt2")

# On a fresh machine the dump is not on disk yet; fetch and extract it so that
# Restart & Run All works. get_data_loaders reads this file from the working directory.
if not os.path.exists("pol_062016-112019_labeled.ndjson"):
    download_files()

train, valid = get_data_loaders(tokenizer)
model = Model(n_ctx=CONTEXT_LENGTH, vocab_size=tokenizer.n_vocab).to("cuda")
optimizer = AdamW(params=get_parameters(model), lr=LEARNING_RATE)
scheduler = get_scheduler(train, optimizer)
loss_function = CrossEntropyLoss(reduction="mean", ignore_index=PADDING_TOKEN)
scaler = GradScaler()

In [None]:
%%time
train_model(model, train, valid, loss_function, optimizer, scheduler, scaler)

In [None]:
predict(text="There", model=model, tokenizer=tokenizer)