In [None]:
from math import sqrt
from os import mkdir, replace
from os.path import isdir, isfile
from re import sub

from accelerate import Accelerator
from pandas import read_parquet
from requests import get
from torch import (
    tril,
    ones,
    softmax,
    long,
    arange,
    multinomial,
    topk,
    no_grad,
    tensor,
    manual_seed,
)
from torch.nn import (
    Linear,
    Embedding,
    ModuleList,
    Module,
    LayerNorm,
    GELU,
    ModuleDict,
    CrossEntropyLoss,
)
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer

In [None]:
class CausalSelfAttention(Module):
    """Multi-head masked self-attention.

    ``config`` must expose ``n_embd``, ``n_head`` (``n_embd`` divisible by
    ``n_head``) and ``block_size``. ``forward`` takes a ``(B, T, n_embd)`` tensor
    with ``T <= block_size`` and returns a tensor of the same shape in which
    position ``t`` only attends to positions ``<= t``.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # A single projection yields queries, keys and values side by side.
        self.c_attn = Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # Lower-triangular causal mask, registered as a (non-learnable) buffer.
        # It keeps the name "bias" so state_dict keys stay unchanged.
        size = config.block_size
        causal_mask = tril(ones(size, size)).view(1, 1, size, size)
        self.register_buffer("bias", causal_mask)

    def forward(self, x):
        batch, seq, channels = x.size()
        head_dim = channels // self.n_head
        # (B, T, C) -> (B, n_head, T, head_dim) for each of q, k, v.
        query, key, value = (
            part.view(batch, seq, self.n_head, head_dim).transpose(1, 2)
            for part in self.c_attn(x).split(self.n_embd, dim=2)
        )
        scores = (query @ key.transpose(-2, -1)) * (1.0 / sqrt(key.size(-1)))
        future = self.bias[:, :, :seq, :seq] == 0
        weights = softmax(scores.masked_fill(future, float("-inf")), dim=-1)
        # Re-assemble the heads: (B, n_head, T, head_dim) -> (B, T, C).
        merged = (weights @ value).transpose(1, 2).contiguous().view(batch, seq, channels)
        return self.c_proj(merged)


class MLP(Module):
    """Position-wise feed-forward network.

    ``Linear(n_embd -> 4 * n_embd)`` followed by tanh-approximated GELU and
    ``Linear(4 * n_embd -> n_embd)``; the last dimension of the input must be
    ``config.n_embd``.
    """

    def __init__(self, config):
        super().__init__()
        hidden_width = 4 * config.n_embd
        self.c_fc = Linear(config.n_embd, hidden_width)
        self.gelu = GELU(approximate="tanh")
        self.c_proj = Linear(hidden_width, config.n_embd)

    def forward(self, x):
        return self.c_proj(self.gelu(self.c_fc(x)))


class Block(Module):
    """Pre-LayerNorm transformer block.

    Computes ``h = x + attn(ln_1(x))`` and then ``h + mlp(ln_2(h))``; the input
    and output shapes are identical.
    """

    def __init__(self, config):
        super().__init__()
        # Creation order (ln_1, attn, ln_2, mlp) fixes parameter order and the
        # sequence in which weight initialisation draws random numbers.
        self.ln_1 = LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        after_attention = x + self.attn(self.ln_1(x))
        return after_attention + self.mlp(self.ln_2(after_attention))


class GPT(Module):
    """Decoder-only transformer language model.

    The defaults are GPT-2 small's sizes (1024-token context, 50257-token
    vocabulary, 12 layers, 12 heads, 768-dim embeddings). The instance itself is
    passed to every ``Block`` as its ``config``, which is why the hyper-parameters
    are stored as attributes before the blocks are built. ``lm_head`` is a separate
    bias-free projection; its weight is not tied to ``wte``.
    """

    def __init__(
        self,
        block_size=1024,
        vocab_size=50257,
        n_layer=12,
        n_head=12,
        n_embd=768,
    ):
        super().__init__()
        # Hyper-parameters first: Block(self) below reads them from this module.
        self.block_size = block_size
        self.vocab_size = vocab_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd

        # Submodules are created in the order wte, wpe, h, ln_f, lm_head, which
        # determines parameter order and the RNG draws used for initialisation.
        layers = {
            "wte": Embedding(vocab_size, n_embd),
            "wpe": Embedding(block_size, n_embd),
            "h": ModuleList([Block(self) for _ in range(n_layer)]),
            "ln_f": LayerNorm(n_embd),
        }
        self.transformer = ModuleDict(layers)
        self.lm_head = Linear(n_embd, vocab_size, bias=False)

    def forward(self, idx):
        """Map ``(B, T)`` token ids, ``T <= block_size``, to ``(B, T, vocab_size)`` logits."""
        _, seq_len = idx.size()
        assert seq_len <= self.block_size
        positions = arange(0, seq_len, dtype=long, device=idx.device)
        hidden = self.transformer.wte(idx) + self.transformer.wpe(positions)
        for layer in self.transformer.h:
            hidden = layer(hidden)
        return self.lm_head(self.transformer.ln_f(hidden))

In [None]:
def evaluate(model, eval_dataloader, loss_fn, accelerator):
    """Return the mean next-token loss of ``model`` over ``eval_dataloader``.

    For every batch the logits at positions ``0..T-2`` are scored against the
    tokens at ``1..T-1``. The result is the unweighted mean of the per-batch
    losses (each already reduced by ``loss_fn``), averaged across processes.
    The model's train/eval mode is restored afterwards, even on error.

    Args:
        model: module mapping a ``(B, T)`` id tensor to ``(B, T, vocab)`` logits.
        eval_dataloader: iterable of token-id batches.
        loss_fn: loss taking ``(N, vocab)`` logits and ``(N,)`` targets.
        accelerator: object providing ``gather`` (e.g. ``accelerate.Accelerator``).

    Returns:
        float: the mean loss.

    Raises:
        ValueError: if ``eval_dataloader`` yields no batches.
    """
    was_training = model.training
    model.eval()
    losses = []

    try:
        with no_grad():
            for batch in eval_dataloader:
                references = batch[..., 1:].contiguous().view(-1)
                predictions = model(batch)[..., :-1, :].contiguous()
                predictions = predictions.view(-1, predictions.size(-1))
                loss = loss_fn(predictions, references)
                # gather() concatenates the per-process losses along dim 0, so
                # with several processes it has one entry per process; average
                # it so .item() below works regardless of the process count.
                losses.append(accelerator.gather(loss).mean())
    finally:
        model.train(was_training)

    if not losses:
        raise ValueError("evaluate() got an empty dataloader; cannot average the loss")
    return (sum(losses) / len(losses)).item()

In [None]:
def get_grouped_params(model, no_decay, wd):
    """Split ``model``'s parameters into two AdamW parameter groups.

    A parameter whose name contains any substring from ``no_decay`` goes into a
    group with ``weight_decay`` 0.0; every other parameter gets ``weight_decay=wd``.
    Parameters keep their ``named_parameters()`` order within each group.

    Returns:
        ``[decay_group, no_decay_group]``, each a dict with keys
        ``"params"`` and ``"weight_decay"``.
    """
    decay_group = {"params": [], "weight_decay": wd}
    no_decay_group = {"params": [], "weight_decay": 0.0}
    for name, param in model.named_parameters():
        exempt = any(pattern in name for pattern in no_decay)
        (no_decay_group if exempt else decay_group)["params"].append(param)
    return [decay_group, no_decay_group]

In [None]:
def predict(text, tokenizer, model, top_k=50):
    """Continue ``text`` with top-k sampling until EOS or the context window is full.

    Each step re-tokenizes the whole text, takes the logits of its last position,
    samples the next token among the ``top_k`` most probable ones, decodes it and
    appends it to the text.

    Args:
        text: prompt to continue; must encode to at least one token.
        tokenizer: tokenizer supporting ``__call__(..., return_tensors="pt")``,
            ``decode`` and ``eos_token_id`` (e.g. a Hugging Face tokenizer).
        model: module mapping ``(1, T)`` token ids to ``(1, T, vocab)`` logits.
        top_k: number of candidate tokens to sample from (clamped to the vocab size).

    Returns:
        ``text`` followed by the sampled continuation (the EOS token is not added).

    Raises:
        ValueError: if ``text`` encodes to zero tokens.
    """
    # Run on the model's own device instead of assuming CUDA is available.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else "cpu"
    # Context length; look through wrappers (e.g. DDP) that expose `.module`.
    block_size = getattr(getattr(model, "module", model), "block_size", None)

    while True:
        toks = tokenizer(text, return_tensors="pt", truncation=True)["input_ids"].to(device)
        if toks.size(-1) == 0:
            raise ValueError("predict() needs a prompt that encodes to at least one token")
        if block_size is not None and toks.size(-1) > block_size:
            break  # the text no longer fits in the model's context window

        try:
            with no_grad():
                # Logits of the last position of the last (only) sequence.
                logits = model(toks)[-1][-1]
        except AssertionError:
            # Fallback when block_size could not be read: GPT.forward asserts
            # that the sequence fits in its context window.
            break

        probs = softmax(logits, dim=-1)
        probs, inds = topk(probs, min(top_k, probs.size(-1)))
        # multinomial accepts the unnormalised top-k probabilities directly.
        selection = inds[multinomial(probs, num_samples=1).item()].item()

        if selection == tokenizer.eos_token_id:
            break

        text += tokenizer.decode(selection)

    return text

In [None]:
def eval(step, steps, model, eval_dataloader, train_dataloader, loss_fn, accelerator):
    """Print progress plus the current training and validation loss.

    The training loss is computed first, then the validation loss, each with
    ``evaluate`` over the whole respective dataloader.

    NOTE: this name shadows Python's built-in ``eval`` in the notebook namespace;
    it is kept because ``train`` calls it by this name.
    """
    train_loss = evaluate(model, train_dataloader, loss_fn, accelerator)
    valid_loss = evaluate(model, eval_dataloader, loss_fn, accelerator)

    report = (
        f"Progress:        {step}/{steps}",
        f"Training loss:   {train_loss}",
        f"Validation loss: {valid_loss}",
    )
    print("\n".join(report))


def _optimizer_step(model, accelerator, optimizer, lr_scheduler):
    """Clip gradients to total norm 1.0, apply one optimizer and scheduler step, clear grads."""
    accelerator.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()


def train(
    epochs,
    train_dataloader,
    eval_dataloader,
    model,
    loss_fn,
    acc,
    accelerator,
    optimizer,
    lr_scheduler,
    info_n=400,
):
    """Train ``model`` on next-token prediction with gradient accumulation.

    Each batch is assumed to be a ``(B, T)`` tensor of token ids; logits at
    positions ``0..T-2`` are scored against tokens ``1..T-1``.

    Args:
        epochs: number of passes over ``train_dataloader``.
        train_dataloader: loader of training batches.
        eval_dataloader: loader of validation batches.
        model: module returning ``(B, T, vocab)`` logits.
        loss_fn: loss taking ``(N, vocab)`` logits and ``(N,)`` targets.
        acc: number of batches whose gradients are accumulated per optimizer step.
        accelerator: ``accelerate.Accelerator`` used for backward and clipping.
        optimizer: optimizer stepped once per accumulation window.
        lr_scheduler: scheduler stepped once per optimizer step.
        info_n: report training/validation loss every ``info_n`` batches
            (counted globally across epochs).
    """
    batches_per_epoch = len(train_dataloader)
    steps = epochs * batches_per_epoch
    eval(0, steps, model, eval_dataloader, train_dataloader, loss_fn, accelerator)

    for epoch in range(epochs):
        for step, batch in enumerate(train_dataloader, start=1):
            predictions = model(batch)[..., :-1, :].contiguous()
            predictions = predictions.view(-1, predictions.size(-1))
            references = batch[..., 1:].contiguous().view(-1)

            # Divide by acc so the accumulated gradient is the window's mean.
            loss = loss_fn(predictions, references) / acc
            accelerator.backward(loss)

            if step % acc == 0:
                _optimizer_step(model, accelerator, optimizer, lr_scheduler)

            # Report against a global batch counter so "step/steps" stays
            # meaningful across epochs; the final report after the loop covers
            # the very last step, so it is not evaluated twice.
            global_step = epoch * batches_per_epoch + step
            if global_step % info_n == 0 and global_step != steps:
                eval(
                    global_step,
                    steps,
                    model,
                    eval_dataloader,
                    train_dataloader,
                    loss_fn,
                    accelerator,
                )

        # Apply gradients of a trailing partial window (fewer than `acc`
        # batches) instead of letting them leak into the next epoch's first
        # window or be discarded after the last epoch.
        if batches_per_epoch % acc != 0:
            _optimizer_step(model, accelerator, optimizer, lr_scheduler)

    eval(steps, steps, model, eval_dataloader, train_dataloader, loss_fn, accelerator)

In [None]:
class PandasDataset(Dataset):
    """Map-style dataset over a DataFrame whose ``"posts"`` column holds token-id lists.

    Item ``idx`` is the ``"posts"`` value of the ``idx``-th row (positional, via
    ``iloc``) converted with ``torch.tensor``.
    """

    def __init__(self, dataframe):
        self.df = dataframe

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        token_ids = self.df["posts"].iloc[idx]
        return tensor(token_ids)


def tokenize(example):
    """Split ``example`` into token-id chunks of exactly ``context_length`` tokens.

    Relies on the notebook-level ``tokenizer`` and ``context_length``. With
    ``truncation=True`` and ``return_overflowing_tokens=True`` a fast Hugging Face
    tokenizer returns the text as consecutive ``context_length``-token windows,
    only the last of which can be shorter.

    Returns:
        list[list[int]]: the full-length windows (possibly none).
    """
    windows = tokenizer(
        example,
        max_length=context_length,
        truncation=True,
        return_overflowing_tokens=True,
    )["input_ids"]
    # Keep only full windows so DataLoader's default collate can stack them.
    # Filtering by length (rather than always dropping the last window) keeps
    # the final window when the text length is an exact multiple of
    # context_length.
    return [window for window in windows if len(window) == context_length]


def parse_posts(posts):
    """Turn a thread's posts into a single training string.

    From each post's ``"content"``, reply references of the form ``>>`` followed
    by exactly nine digits (and at most one trailing space) are removed, and
    the result is stripped. The posts are then joined with the notebook-level
    ``tokenizer``'s BOS token.
    """
    # NOTE(review): assumes every post's "content" is a string; None would make
    # sub() raise — confirm against the dataset.
    cleaned = []
    for post in posts:
        without_refs = sub(">>[0-9]{9} {0,1}", "", post["content"])
        cleaned.append(without_refs.strip())
    return tokenizer.bos_token.join(cleaned)


def download_files():
    """Download the two 4chan-archive parquet shards into ``datasets/``.

    A shard is fetched only if ``datasets/<i>.parquet`` does not exist yet, so an
    interrupted run resumes instead of leaving a half-filled directory that is
    skipped forever. Each shard is streamed to a dot-prefixed temporary file
    (pyarrow's directory reader skips files starting with "." or "_" by default)
    and moved into place only once the download has completed.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    if not isdir("datasets"):
        mkdir("datasets")

    urls = [
        f"https://huggingface.co/datasets/Fal7acy/4chan-archive/resolve/main/data/train-0000{i}-of-00002.parquet?download=true"
        for i in range(2)
    ]

    for i, url in enumerate(urls):
        target = f"datasets/{i}.parquet"
        if isfile(target):
            continue
        partial = f"datasets/.{i}.parquet.part"
        # The context manager closes the streamed connection.
        with get(url, stream=True, timeout=60) as response:
            # Without this check an HTML error page would be saved as parquet.
            response.raise_for_status()
            with open(partial, "wb") as handle:
                for data in response.iter_content(chunk_size=1 << 20):
                    handle.write(data)
        replace(partial, target)


def _load_pol_chunks():
    """Read ``datasets/``, keep /pol/ threads, and return one token chunk per row in ``"posts"``."""
    threads = read_parquet("datasets", engine="pyarrow")
    pol = threads[threads["board"] == "pol"].reset_index(drop=True)
    pol["posts"] = pol["posts"].apply(parse_posts)
    pol = pol.drop(["board", "thread"], axis=1)
    pol["posts"] = pol["posts"].apply(tokenize)
    # explode() yields one row per chunk; threads that produced no chunk
    # become NaN rows, which dropna() removes.
    return pol.explode("posts").reset_index(drop=True).dropna()


def get_dataloader(p=None, split=0.8, batch_size=None):
    """Build a shuffled training and an ordered validation DataLoader of token chunks.

    Args:
        p: optional fraction of the chunks to keep, taken from the start;
            ``None`` keeps all of them.
        split: fraction of the kept chunks used for training; must be in (0, 1).
        batch_size: DataLoader batch size; defaults to the notebook-level ``bs``.

    Returns:
        ``(train, valid)`` DataLoaders over ``PandasDataset`` instances.

    Raises:
        ValueError: if ``split`` is not strictly between 0 and 1.
    """
    if not 0 < split < 1:
        raise ValueError(f"split must be strictly between 0 and 1, got {split!r}")
    if batch_size is None:
        batch_size = bs

    chunks = _load_pol_chunks()
    if p is not None:
        chunks = chunks.iloc[: int(len(chunks) * p)]

    # Contiguous (unshuffled) split: consecutive chunks of one thread end up on
    # the same side, except possibly at the cut point.
    cutoff = int(len(chunks) * split)
    train = DataLoader(PandasDataset(chunks.iloc[:cutoff]), batch_size=batch_size, shuffle=True)
    valid = DataLoader(PandasDataset(chunks.iloc[cutoff:]), batch_size=batch_size, shuffle=False)
    return train, valid

In [None]:
context_length = 128  # tokens per training chunk; also used as the model's block_size
epochs = 1
acc = 16  # batches accumulated per optimizer step (effective batch = bs * acc)
bs = 16  # DataLoader batch size
lr = 5e-4
warmup_steps = 1000  # counted in optimizer steps (the scheduler steps once per update)
wd = 0.1
p = 0.05  # fraction of the tokenized /pol/ chunks to keep (None = all)
split = 0.8  # training fraction; the rest is used for validation
info_n = 400  # report losses every info_n batches
seed = 0  # seeds torch's RNG: weight init, DataLoader shuffling, sampling
# Substrings of parameter names exempt from weight decay. This model names its
# LayerNorms ln_1 / ln_2 / ln_f, so a "LayerNorm.weight" pattern would never
# match and their weights would be decayed.
no_decay = ["bias", "ln_1.weight", "ln_2.weight", "ln_f.weight"]

manual_seed(seed)

In [None]:
# Fetch the dataset parquet shards into ./datasets/ before building the dataloaders.
download_files()

In [None]:
tokenizer = AutoTokenizer.from_pretrained("gpt2")

train_dl, valid_dl = get_dataloader(p, split)

model = GPT(block_size=context_length, vocab_size=len(tokenizer))

params = get_grouped_params(model=model, no_decay=no_decay, wd=wd)
optimizer = AdamW(params=params, lr=lr)

accelerator = Accelerator(mixed_precision="fp16")
model, optimizer, train_dl, valid_dl = accelerator.prepare(
    model, optimizer, train_dl, valid_dl
)

# train() calls lr_scheduler.step() once per optimizer update, i.e. once every
# `acc` batches of train_dl, so the schedule must be measured in optimizer
# steps. Counting batches instead made the cosine phase roughly `acc` times too
# long, so the learning rate barely decayed.
optimizer_steps = epochs * (len(train_dl) // acc)
cosine_steps = max(1, optimizer_steps - warmup_steps)  # T_max must be positive

# NOTE(review): start_factor multiplies lr, so warmup starts at lr * lr
# (e.g. 2.5e-7 for lr = 5e-4) and rises linearly to lr over warmup_steps updates.
warmup = LinearLR(optimizer, start_factor=lr, total_iters=warmup_steps)
scheduler = CosineAnnealingLR(optimizer, T_max=cosine_steps)
lr_scheduler = SequentialLR(
    optimizer,
    schedulers=[warmup, scheduler],
    milestones=[warmup_steps],
)

loss_fn = CrossEntropyLoss(reduction="mean")

In [None]:
# Run training; losses are printed at the start, every info_n batches, and at the end.
train(
    epochs=epochs,
    train_dataloader=train_dl,
    eval_dataloader=valid_dl,
    model=model,
    loss_fn=loss_fn,
    acc=acc,
    accelerator=accelerator,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    info_n=info_n,
)

In [None]:
# Sample a continuation of a short prompt; the returned string is the cell's output.
predict(text="There", tokenizer=tokenizer, model=model)