In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer
from torch.utils.data.dataloader import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from torch import tril, ones, softmax, long, arange, multinomial, topk, no_grad
from accelerate import Accelerator
from torch import nn
import re
import math

In [None]:
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention where each position attends only to itself and earlier positions.

    ``config`` must provide ``n_embd``, ``n_head`` (``n_embd`` divisible by ``n_head``) and
    ``block_size``. The lower-triangular causal mask is stored as the buffer ``"bias"``; the
    name is kept because renaming it would change the module's ``state_dict`` keys.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # Projections are created in the same order as before so seeded init is unchanged.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)  # fused q, k, v
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        size = config.block_size
        causal_mask = tril(ones(size, size)).view(1, 1, size, size)
        self.register_buffer("bias", causal_mask)

    def forward(self, x):
        """Attend over ``x`` of shape (B, T, n_embd); returns a tensor of the same shape.

        T must not exceed ``block_size`` (the mask is sliced to T x T).
        """
        batch, seq_len, width = x.size()
        head_dim = width // self.n_head

        def split_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        query, key, value = map(split_heads, self.c_attn(x).split(self.n_embd, dim=2))

        scale = 1.0 / math.sqrt(key.size(-1))
        scores = (query @ key.transpose(-2, -1)) * scale
        visible = self.bias[:, :, :seq_len, :seq_len]
        scores = scores.masked_fill(visible == 0, float("-inf"))
        weights = softmax(scores, dim=-1)

        merged = (weights @ value).transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.c_proj(merged)


class MLP(nn.Module):
    """Position-wise feed-forward layer: widen to 4 * n_embd, apply tanh-approximated GELU, project back."""

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden)
        self.gelu = nn.GELU(approximate="tanh")
        self.c_proj = nn.Linear(hidden, config.n_embd)

    def forward(self, x):
        """Apply the feed-forward transform along the last dimension of ``x``."""
        return self.c_proj(self.gelu(self.c_fc(x)))


class Block(nn.Module):
    """Pre-norm transformer layer: LayerNorm -> causal attention and LayerNorm -> MLP, each wrapped in a residual."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        self.ln_1 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(width)
        self.mlp = MLP(config)

    def forward(self, x):
        """Return ``x`` updated by the attention residual and then by the MLP residual."""
        after_attn = x + self.attn(self.ln_1(x))
        return after_attn + self.mlp(self.ln_2(after_attn))


class GPT(nn.Module):
    """GPT-2-style decoder-only language model built from token/position embeddings and ``Block`` layers.

    The instance itself is passed as the ``config`` to every ``Block``, which reads
    ``n_embd``, ``n_head`` and ``block_size`` from it, so those attributes are set
    before the layers are constructed.
    """

    def __init__(
        self,
        block_size=1024,
        vocab_size=50257,
        n_layer=12,
        n_head=12,
        n_embd=768,
    ):
        super().__init__()
        self.block_size, self.vocab_size = block_size, vocab_size
        self.n_layer, self.n_head, self.n_embd = n_layer, n_head, n_embd
        # Submodules are built in the same order (wte, wpe, h, ln_f, lm_head) so that
        # state_dict keys and seeded initialisation match.
        self.transformer = nn.ModuleDict(
            {
                "wte": nn.Embedding(vocab_size, n_embd),
                "wpe": nn.Embedding(block_size, n_embd),
                "h": nn.ModuleList([Block(self) for _ in range(n_layer)]),
                "ln_f": nn.LayerNorm(n_embd),
            }
        )
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

    def forward(self, idx):
        """Map token ids of shape (B, T) to logits of shape (B, T, vocab_size).

        Raises AssertionError (when assertions are enabled) if T exceeds ``block_size``.
        """
        _, seq_len = idx.size()
        assert seq_len <= self.block_size
        layers = self.transformer
        positions = arange(0, seq_len, dtype=long, device=idx.device)
        hidden = layers.wte(idx) + layers.wpe(positions)
        for layer in layers.h:
            hidden = layer(hidden)
        return self.lm_head(layers.ln_f(hidden))

In [None]:
def evaluate(model, eval_dataloader, loss_fn, accelerator):
    """Return the mean next-token loss of ``model`` over ``eval_dataloader``.

    Each batch must contain ``"input_ids"``; logits at position t are scored against the
    token at t + 1. Per-batch losses are gathered across processes with
    ``accelerator.gather`` and averaged into a single scalar tensor.

    Args:
        model: Module mapping ``input_ids`` to logits of shape (B, T, vocab).
        eval_dataloader: Iterable of batches (dicts with ``"input_ids"``).
        loss_fn: Loss taking (N, vocab) logits and (N,) targets, e.g. ``nn.CrossEntropyLoss``.
        accelerator: Provides ``device`` and ``gather``.

    Returns:
        0-dim tensor with the mean batch loss.

    Raises:
        ValueError: if ``eval_dataloader`` yields no batches.
    """
    from torch import cat

    model.eval()
    losses = []
    try:
        for batch in eval_dataloader:
            # Use the accelerator's device instead of a hard-coded "cuda" so this also
            # runs on CPU and on the correct device in multi-process runs.
            input_ids = batch["input_ids"].to(accelerator.device)
            with no_grad():
                logits = model(input_ids)[..., :-1, :].contiguous()
                loss = loss_fn(
                    logits.view(-1, logits.size(-1)),
                    input_ids[..., 1:].contiguous().view(-1),
                )
            # Give the scalar a length-1 dim so gathered results from every process can be concatenated.
            losses.append(accelerator.gather(loss.detach().reshape(1)))
    finally:
        # Restore training mode even if evaluation fails part-way.
        model.train()

    if not losses:
        raise ValueError("eval_dataloader yielded no batches")
    return cat(losses).mean()

In [None]:
def parse_text(example, tokenizer):
    """Collapse a row's ``posts`` list into one string.

    From each post's ``"content"``, strips ``>>`` followed by exactly nine digits (and one
    optional trailing space). It then trims whitespace and drops posts that end up empty.
    The survivors are joined with ``tokenizer.bos_token``.

    Returns:
        ``{"posts": joined_text}``, suitable for ``Dataset.map``.
    """
    cleaned = (
        re.sub(">>[0-9]{9} {0,1}", "", entry["content"]).strip()
        for entry in example["posts"]
    )
    return {"posts": tokenizer.bos_token.join(text for text in cleaned if text)}


def filter_board(example):
    """Predicate for ``Dataset.filter``: keep rows whose ``"board"`` field equals ``"pol"``."""
    board = example["board"]
    return board == "pol"


def tokenize(examples, tokenizer, context_length):
    """Split each text in ``examples["posts"]`` into chunks of ``context_length`` tokens.

    Overflowing tokens are returned as extra chunks by the tokenizer. Any chunk whose
    reported length differs from ``context_length`` (the short tail) is dropped, so every
    returned sequence has exactly that many tokens.

    Returns:
        ``{"input_ids": [chunk, ...]}`` for a batched ``Dataset.map``.
    """
    encoded = tokenizer(
        examples["posts"],
        max_length=context_length,
        truncation=True,
        return_overflowing_tokens=True,
        return_length=True,
    )
    full_chunks = [
        ids
        for ids, size in zip(encoded["input_ids"], encoded["length"])
        if size == context_length
    ]
    return {"input_ids": full_chunks}


def get_dataset(tokenizer):
    """Load the training split of ``Fal7acy/4chan-archive`` and clean it.

    Keeps only rows accepted by ``filter_board`` and rewrites ``posts`` with ``parse_text``.
    Removes the ``board`` and ``thread`` columns and sets the output format to ``"torch"``.
    """
    threads = (
        load_dataset("Fal7acy/4chan-archive")["train"]
        .filter(filter_board)
        .map(lambda row: parse_text(row, tokenizer))
        .remove_columns(["board", "thread"])
    )
    threads.set_format("torch")
    return threads


def get_dataloader(dataset, bs, tokenizer, context_length, test_size=None, seed=None):
    """Tokenize ``dataset`` into fixed-length chunks and build train/eval DataLoaders.

    Args:
        dataset: Dataset with a ``posts`` text column.
        bs: Batch size for both loaders.
        tokenizer: Tokenizer passed through to ``tokenize``.
        context_length: Tokens per chunk; shorter tail chunks are discarded by ``tokenize``.
        test_size: Forwarded to ``train_test_split``; ``None`` keeps its default split.
        seed: Forwarded to ``train_test_split`` so the split can be made reproducible;
            ``None`` keeps the previous (unseeded) behaviour.

    Returns:
        ``(train_dataloader, eval_dataloader)``; only the training loader shuffles.

    Raises:
        ValueError: if tokenization produced no full-length chunks.
    """
    # remove_columns is required: overflow tokenization changes the number of rows,
    # so the original columns can no longer be aligned with the new ones.
    chunks = dataset.map(
        lambda batch: tokenize(batch, tokenizer, context_length),
        batched=True,
        remove_columns=dataset.column_names,
    )
    if len(chunks) == 0:
        raise ValueError(
            f"no chunks of exactly {context_length} tokens were produced; "
            "use more data or a smaller context_length"
        )
    splits = chunks.train_test_split(test_size=test_size, seed=seed)

    train_dataloader = DataLoader(splits["train"], batch_size=bs, shuffle=True)
    eval_dataloader = DataLoader(splits["test"], batch_size=bs)

    return train_dataloader, eval_dataloader

In [None]:
def get_grouped_params(model, no_decay=("bias", "LayerNorm.weight"), weight_decay=0.1):
    """Split ``model``'s parameters into AdamW groups with and without weight decay.

    A parameter is exempt from decay when its name contains any substring in ``no_decay``
    or when it has fewer than two dimensions (biases, LayerNorm weights/biases). The
    dimensionality rule is needed because name matching alone misses norms not literally
    named "LayerNorm". In this notebook's GPT they are ``ln_1``/``ln_2``/``ln_f``, so
    their gains would otherwise be decayed toward zero.

    Args:
        model: Module whose ``named_parameters()`` are grouped.
        no_decay: Name substrings that exempt a parameter from decay.
        weight_decay: Decay applied to the decayed group.

    Returns:
        Two AdamW param-group dicts: decayed first, then non-decayed.
    """
    params_with_wd, params_without_wd = [], []
    for name, param in model.named_parameters():
        if param.ndim < 2 or any(nd in name for nd in no_decay):
            params_without_wd.append(param)
        else:
            params_with_wd.append(param)
    return [
        {"params": params_with_wd, "weight_decay": weight_decay},
        {"params": params_without_wd, "weight_decay": 0.0},
    ]

In [None]:
def predict(text, tokenizer, model, top_k=50):
    """Sample a continuation of ``text`` from ``model`` using top-k sampling.

    Token ids are accumulated directly instead of decoding and re-tokenizing every step.
    Round-tripping single tokens through text can split multi-byte characters and
    change BPE merges. Generation stops when the model emits ``tokenizer.eos_token_id``
    or when the context would exceed the model's ``block_size``. This is an explicit
    length check rather than catching the model's AssertionError, which disappears
    under ``python -O``.

    Args:
        text: Prompt string.
        tokenizer: Callable returning ``{"input_ids": [...]}``, with ``decode``,
            ``eos_token_id`` and ``bos_token_id``.
        model: Module mapping (1, T) token ids to (1, T, vocab) logits and exposing
            ``block_size``, either directly or on a wrapped ``.module``.
        top_k: Number of most-likely tokens sampled from at each step (capped at vocab size).

    Returns:
        The prompt followed by the decoded continuation.
    """
    from torch import tensor

    block_size = getattr(model, "module", model).block_size
    # Follow the model's device instead of a hard-coded "cuda".
    device = next(model.parameters()).device

    prompt_ids = list(tokenizer(text)["input_ids"])
    if not prompt_ids:
        # An empty context cannot be fed to the model; condition on the BOS token instead.
        prompt_ids = [tokenizer.bos_token_id]

    new_ids = []
    while len(prompt_ids) + len(new_ids) <= block_size:
        context = tensor([prompt_ids + new_ids], dtype=long, device=device)
        with no_grad():
            logits = model(context)[0, -1]
        probs = softmax(logits.float(), dim=-1)
        top_probs, top_ids = topk(probs, min(top_k, probs.size(-1)))
        selection = top_ids[multinomial(top_probs, num_samples=1)].item()

        if selection == tokenizer.eos_token_id:
            break
        new_ids.append(selection)

    return text + tokenizer.decode(new_ids)

In [None]:
def train(
    epochs,
    train_dataloader,
    model,
    loss_fn,
    acc,
    accelerator,
    optimizer,
    lr_scheduler,
    steps,
):
    """Train ``model`` for ``epochs`` passes with gradient accumulation over ``acc`` batches.

    Each batch's next-token loss is divided by ``acc`` and backpropagated through
    ``accelerator.backward``. After every ``acc`` batches, and also after the last batch
    of each epoch, gradients are clipped to norm 1.0. Then the optimizer and scheduler
    step and gradients are cleared. Flushing at the epoch end stops a trailing partial
    window from being silently dropped or leaking into the next epoch.

    Args:
        epochs: Number of passes over ``train_dataloader``.
        train_dataloader: Sized iterable of batches (dicts with ``"input_ids"``).
        model: Module mapping ``input_ids`` to (B, T, vocab) logits.
        loss_fn: Loss taking (N, vocab) logits and (N,) targets.
        acc: Number of batches accumulated per optimizer step.
        accelerator: Provides ``backward`` and ``clip_grad_norm_``.
        optimizer: Optimizer over ``model``'s parameters.
        lr_scheduler: Stepped once per optimizer step.
        steps: Total shown in the progress output.
    """
    model.train()
    num_batches = len(train_dataloader)
    for _ in range(epochs):
        print(f"0/{steps}")
        for step, batch in enumerate(train_dataloader, start=1):
            logits = model(batch["input_ids"])[..., :-1, :].contiguous()
            loss = (
                loss_fn(
                    logits.view(-1, logits.size(-1)),
                    batch["input_ids"][..., 1:].contiguous().view(-1),
                )
                / acc
            )

            accelerator.backward(loss)

            # A partial final window has fewer than `acc` contributions, so its update
            # is proportionally smaller; applying it beats discarding it.
            if step % acc == 0 or step == num_batches:
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            if step % 100 == 0:
                print(f"{step}/{steps}")

In [None]:
# --- data ---
n = 100  # number of dataset rows kept via .select(range(n)) for this run
context_length = 128  # tokens per training chunk; also used as the model's block_size

# --- optimisation ---
epochs = 1
bs = 16  # batch size for both dataloaders
acc = 16  # batches accumulated per optimizer step
lr = 5e-4  # AdamW learning rate
warmup = 1_000  # length of the linear-warmup scheduler phase, in scheduler steps

In [None]:
tokenizer = AutoTokenizer.from_pretrained("gpt2")

model = GPT(block_size=context_length, vocab_size=len(tokenizer))

dataset = get_dataset(tokenizer).select(range(n))
train_dataloader, eval_dataloader = get_dataloader(dataset, bs, tokenizer, context_length)

optimizer = AdamW(get_grouped_params(model), lr=lr)

accelerator = Accelerator(mixed_precision="fp16")
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

# `steps` counts batches; train() only uses it for progress output.
steps = epochs * len(train_dataloader)

# train() steps the scheduler once per optimizer update (every `acc` batches), not once
# per batch, so the schedule length must be counted in updates. Rounding up keeps the
# schedule from running past its end if a trailing partial window is also applied.
optimizer_steps = epochs * math.ceil(len(train_dataloader) / acc)
# Clamp warmup so short runs (small `n`) still reach the cosine phase and the cosine
# period below is never zero or negative.
warmup_steps = max(1, min(warmup, optimizer_steps - 1))
# LinearLR's start_factor multiplies `lr`; it is a fraction, not a learning rate.
warmup_start_factor = 1e-3

lr_scheduler1 = LinearLR(optimizer, start_factor=warmup_start_factor, total_iters=warmup_steps)
lr_scheduler2 = CosineAnnealingLR(optimizer, T_max=max(1, optimizer_steps - warmup_steps))
lr_scheduler = SequentialLR(
    optimizer,
    schedulers=[lr_scheduler1, lr_scheduler2],
    milestones=[warmup_steps],
)

loss_fn = nn.CrossEntropyLoss(reduction="mean")

In [None]:
train(
    epochs=epochs,
    train_dataloader=train_dataloader,
    model=model,
    loss_fn=loss_fn,
    acc=acc,
    accelerator=accelerator,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    steps=steps,
)

In [None]:
evaluate(
    model=model,
    eval_dataloader=eval_dataloader,
    loss_fn=loss_fn,
    accelerator=accelerator,
)

In [None]:
predict(text="There", tokenizer=tokenizer, model=model)