In [None]:
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer, GPT2LMHeadModel, AutoConfig, pipeline
from torch.nn import CrossEntropyLoss
from torch.utils.data.dataloader import DataLoader
from torch.optim import AdamW, lr_scheduler
from accelerate import Accelerator
from typing import defaultdict
import torch
import tqdm
import re
import pandas as pd
from unidecode import unidecode
import string

In [None]:
class Tokenizer:
    """Minimal byte-pair-encoding (BPE) tokenizer trained from raw texts.

    Training splits the (ASCII-folded, lower-cased) texts on whitespace,
    starts from single characters and repeatedly merges the most frequent
    adjacent symbol pair. The merges are kept as an *ordered* list because
    encoding must replay them in exactly the order they were learned.

    Calling the instance encodes text into fixed-length chunks of token ids
    (the final chunk is right-padded) together with a matching attention mask.

    Attributes:
        pad_token, unk_token: special token strings.
        pad_token_id, unk_token_id: ids of the special tokens.
        max_len: chunk length produced by ``__call__``.
        rules: list of learned merge pairs, in learning order.
        token2id, id2token: vocabulary mappings.
    """

    def _get_most_freq_pair(self, token_freqs):
        """Return the most frequent adjacent symbol pair, or None if there is none."""
        pair_freqs = {}

        for token, freq in token_freqs.items():
            for i in range(len(token) - 1):
                pair = tuple(token[i : i + 2])
                pair_freqs[pair] = pair_freqs.get(pair, 0) + freq

        if not pair_freqs:
            # Every word has collapsed into a single symbol: nothing left to merge.
            return None

        # max() keeps the first maximum, so ties resolve by first occurrence.
        return max(pair_freqs.items(), key=lambda x: x[1])[0]

    @staticmethod
    def _apply_merge(symbols, merge_pair):
        """Replace each left-to-right occurrence of ``merge_pair`` in ``symbols``.

        ``symbols`` may be any indexable sequence of strings; a new list is returned.
        """
        first, second = merge_pair
        merged = first + second
        result = []
        i = 0
        n = len(symbols)

        while i < n:
            if i + 1 < n and symbols[i] == first and symbols[i + 1] == second:
                result.append(merged)
                i += 2
            else:
                result.append(symbols[i])
                i += 1

        return result

    def _merge_tokens(self, token_freqs, merge_pair):
        """Apply one merge to every word segmentation in ``token_freqs``."""
        token_freqs_new = {}

        for token, freq in token_freqs.items():
            token_new = tuple(self._apply_merge(token, merge_pair))
            token_freqs_new[token_new] = token_freqs_new.get(token_new, 0) + freq

        return token_freqs_new

    def _get_word_freqs(self, texts):
        """Count whitespace-separated words over all ``texts``."""
        word_freqs = {}

        for text in texts:
            for word in text.split():
                word_freqs[word] = word_freqs.get(word, 0) + 1

        return word_freqs

    def _get_token_freqs(self, word_freqs):
        """Turn each word into a tuple of single characters, keeping its count."""
        return {tuple(word): freq for word, freq in word_freqs.items()}

    def _pre_tokenize(self, text):
        """Normalise text: transliterate to ASCII with ``unidecode`` and lower-case."""
        return unidecode(text).lower()

    def _build_vocab(self, rules, pad_token, unk_token):
        """Build the vocabulary as a duplicate-free list in a deterministic order.

        Order: special tokens, sorted base characters, merged tokens in rule
        order. A list (rather than a set) keeps token ids identical across
        interpreter runs, which matters as soon as a model is saved and reloaded.
        """
        candidates = [pad_token, unk_token]
        candidates += sorted(set(self._pre_tokenize(string.printable)))
        candidates += ["".join(pair) for pair in rules]

        # Distinct pairs can join to the same string; keep the first occurrence.
        return list(dict.fromkeys(candidates))

    def _train(self, rules_size, token_freqs):
        """Learn up to ``rules_size`` merges; stops early when no pair is left."""
        rules = []

        while len(rules) < rules_size:
            pair = self._get_most_freq_pair(token_freqs)
            if pair is None:
                break
            rules.append(pair)
            token_freqs = self._merge_tokens(token_freqs, pair)

        return rules

    def _tokenize(self, texts):
        """Split ``texts`` (a string or an iterable of strings) into BPE tokens.

        All texts are concatenated into one symbol stream. Merges never span
        whitespace because rules are learned on whitespace-separated words.
        """
        if isinstance(texts, str):
            texts = [texts]

        tokens = [char for text in texts for char in self._pre_tokenize(text)]

        # Replay the merges in the order they were learned.
        for merge_pair in self.rules:
            tokens = self._apply_merge(tokens, merge_pair)

        return tokens

    def _numericalize(self, tokens):
        """Map tokens to ids, using the unknown id for out-of-vocabulary tokens."""
        return [self.token2id.get(token, self.unk_token_id) for token in tokens]

    def _chunk(self, ids):
        """Cut ``ids`` into consecutive chunks of at most ``max_len`` items."""
        return [ids[i : i + self.max_len] for i in range(0, len(ids), self.max_len)]

    def _pad(self, ids):
        """Right-pad the last chunk to ``max_len``; returns (ids, padding_amount)."""
        if not ids:
            # Empty input produces no chunks, hence nothing to pad.
            return ids, 0

        padding_amount = self.max_len - len(ids[-1])
        ids[-1] += [self.pad_token_id] * padding_amount

        return ids, padding_amount

    def _get_mask(self, ids, padding_amount):
        """Return one mask row per chunk: 1 for real tokens, 0 for padding."""
        if not ids:
            return []

        full_rows = [[1] * self.max_len for _ in range(len(ids) - 1)]
        last_row = [1] * (self.max_len - padding_amount) + [0] * padding_amount

        return full_rows + [last_row]

    def __init__(self, texts, rules_size, max_len):
        """Train the tokenizer.

        Args:
            texts: iterable of training strings.
            rules_size: maximum number of merges to learn.
            max_len: chunk length used by ``__call__``; must be at least 1.

        Raises:
            ValueError: if ``max_len`` is smaller than 1.
        """
        if max_len < 1:
            raise ValueError("max_len must be at least 1")

        pad_token = "<|pad|>"
        unk_token = "<|unk|>"

        texts = [self._pre_tokenize(text) for text in texts]
        word_freqs = self._get_word_freqs(texts)
        token_freqs = self._get_token_freqs(word_freqs)
        rules = self._train(rules_size, token_freqs)
        vocab = self._build_vocab(rules, pad_token, unk_token)

        self.pad_token = pad_token
        self.unk_token = unk_token
        self.max_len = max_len
        self.rules = rules
        self.token2id = {token: i for i, token in enumerate(vocab)}
        self.id2token = {i: token for token, i in self.token2id.items()}
        self.pad_token_id = self.token2id[self.pad_token]
        self.unk_token_id = self.token2id[self.unk_token]

    def __call__(self, texts):
        """Encode ``texts`` into ``max_len``-sized chunks.

        Returns:
            dict with ``"input_ids"`` and ``"attention_mask"``, each a list of
            equally long rows. Both lists are empty for empty input.
        """
        tokens = self._tokenize(texts)
        ids = self._numericalize(tokens)
        ids = self._chunk(ids)
        ids, padding_amount = self._pad(ids)
        mask = self._get_mask(ids, padding_amount)

        return {
            "input_ids": ids,
            "attention_mask": mask,
        }

    def __len__(self):
        """Vocabulary size."""
        return len(self.token2id)

    def decode(self, inputs):
        """Turn ids back into text.

        Args:
            inputs: a flat iterable of ids, or the dict returned by ``__call__``.

        Padding ids are dropped and unknown ids are rendered as ``"?"``.
        """
        output = []

        if isinstance(inputs, dict):
            inputs = [input_id for input_ids in inputs["input_ids"] for input_id in input_ids]

        for input_id in inputs:
            # int() lets scalar tensors / numpy integers be used as dict keys.
            input_id = int(input_id)

            if input_id == self.pad_token_id:
                continue
            if input_id == self.unk_token_id:
                token = "?"
            else:
                token = self.id2token[input_id]

            output.append(token)

        return "".join(output)

In [None]:
NNN = 1000      # number of examples per split
epochsN = 3     # number of training epochs

# Load and shuffle once, then carve out two *disjoint* slices so that the
# validation examples are never seen during training (previously both splits
# selected the very same rows, so the eval loss only measured memorisation).
raw_shuffled = load_dataset("huggingface-course/codeparrot-ds-train")["train"].shuffle(seed=42)

dataset = DatasetDict({
    "train": raw_shuffled.select(range(NNN)),
    "valid": raw_shuffled.select(range(NNN, 2 * NNN)),
})

context_length = 128
# Fit the tokenizer on the training split only, to avoid leaking validation text.
corpset = list(dataset["train"]["content"])

vs = 3000  # number of BPE merges to learn
tokenizer = Tokenizer(corpset, vs, context_length)

In [None]:
%%time

# Tokenize both splits. With batched=True the lambda receives a batch of
# "content" strings and returns chunked rows, so the number of output rows
# differs from the number of input rows; that is why every original column is
# removed. load_from_cache_file=False forces recomputation on each run.
dataset_tokenized = dataset.map(
    lambda x: tokenizer(x["content"]),
    batched=True,
    remove_columns=dataset["train"].column_names,
    load_from_cache_file=False
)

In [None]:
# Start from the stock GPT-2 configuration and override what differs here:
# the vocabulary size of the custom tokenizer and the context length.
# `n_positions` is the parameter that sizes the position embeddings;
# `n_ctx` is kept for older transformers releases that still read it.
config = AutoConfig.from_pretrained(
    "gpt2",
    vocab_size=len(tokenizer),
    n_ctx=context_length,
    n_positions=context_length,
)

In [None]:
bs = 8  # batch size used by both dataloaders

# Switch the dataset output format to torch (mutates `dataset_tokenized` in place).
dataset_tokenized.set_format("torch")
# Only the training batches are shuffled.
train_dataloader = DataLoader(dataset_tokenized["train"], batch_size=bs, shuffle=True)
eval_dataloader = DataLoader(dataset_tokenized["valid"], batch_size=bs)

In [None]:
weight_decay = 0.1


def get_grouped_params(
    model,
    no_decay=("bias", "LayerNorm.weight", "ln_1.weight", "ln_2.weight", "ln_f.weight"),
):
    """Split the model's parameters into weight-decay and no-weight-decay groups.

    Args:
        model: object exposing ``named_parameters()`` yielding ``(name, param)``.
        no_decay: substrings; a parameter whose name contains any of them is
            excluded from weight decay. The default covers biases and layer-norm
            weights, including GPT-2's ``ln_1`` / ``ln_2`` / ``ln_f`` naming
            (which the plain ``"LayerNorm.weight"`` pattern does not match).

    Returns:
        Two parameter-group dicts suitable for a torch optimizer: the first uses
        the module-level ``weight_decay``, the second uses ``0.0``.
    """
    params_with_wd, params_without_wd = [], []
    for name, param in model.named_parameters():
        if any(pattern in name for pattern in no_decay):
            params_without_wd.append(param)
        else:
            params_with_wd.append(param)
    return [
        {"params": params_with_wd, "weight_decay": weight_decay},
        {"params": params_without_wd, "weight_decay": 0.0},
    ]

In [None]:
def evaluate():
    """Compute the mean loss and perplexity over ``eval_dataloader``.

    Uses the notebook globals ``model``, ``eval_dataloader`` and ``accelerator``.
    The model is left in eval mode; the caller switches back to train mode.

    Returns:
        ``(loss, perplexity)`` as Python floats. Perplexity is ``inf`` when the
        exponential overflows, and both values are ``nan`` when the dataloader
        yields no batches.
    """
    model.eval()
    losses = []
    for batch in eval_dataloader:
        with torch.no_grad():
            outputs = model(batch["input_ids"], labels=batch["input_ids"])

        gathered = accelerator.gather_for_metrics(outputs.loss)
        # The gathered loss can be a scalar or hold one value per process;
        # flatten so both cases concatenate cleanly.
        losses.append(gathered.detach().float().reshape(-1).cpu())

    if not losses:
        return float("nan"), float("nan")

    loss = torch.cat(losses).mean()
    # torch.exp does not raise on overflow; it yields inf.
    perplexity = torch.exp(loss)
    return loss.item(), perplexity.item()

In [None]:
# Build the language model directly from `config` (constructor call rather
# than `from_pretrained`, so no pretrained checkpoint is loaded here).
model = GPT2LMHeadModel(config)

In [None]:
# AdamW over the two parameter groups from get_grouped_params, so weight
# decay is applied only to the first group. Learning rate: 5e-4.
optimizer = AdamW(get_grouped_params(model), lr=5e-4)

In [None]:
# Accelerate handles device placement and mixed precision (fp16 requested here).
accelerator = Accelerator(mixed_precision="fp16")

# From here on the names refer to the objects returned by `prepare`; cells
# below (e.g. len(train_dataloader)) must run after this one.
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

In [None]:
num_train_epochs = epochsN
# Despite the name, this is the number of *batches* per epoch (length of the
# prepared dataloader), not the number of optimizer updates.
num_update_steps_per_epoch = len(train_dataloader)
# Total number of batches over the whole run; used as the progress-bar total.
num_training_steps = num_train_epochs * num_update_steps_per_epoch
# Number of batches whose gradients are accumulated per optimizer step.
gradient_accumulation_steps = 8
# The training loop evaluates every eval_steps * gradient_accumulation_steps batches.
eval_steps = 50
# Counter of optimizer updates, incremented by the training loop.
completed_steps = 0

In [None]:
# The scheduler is stepped once per optimizer update, i.e. once every
# `gradient_accumulation_steps` batches, so its horizon must be counted in
# updates rather than batches (otherwise the schedule is stretched by that
# factor and never completes). max(1, ...) guards very small datasets.
num_optimizer_steps = max(
    1,
    num_train_epochs * (num_update_steps_per_epoch // gradient_accumulation_steps),
)
# NOTE(review): with its default factors LinearLR ramps the learning rate *up*
# towards the base value over `total_iters` - confirm a warm-up is intended.
lr_sched = lr_scheduler.LinearLR(optimizer, total_iters=num_optimizer_steps)

In [None]:
model.train()
loss_fct = CrossEntropyLoss()
batches_per_epoch = len(train_dataloader)
progress = tqdm.tqdm(total=num_training_steps)

for epoch in range(num_train_epochs):
    for step, batch in enumerate(train_dataloader, start=1):
        logits = model(batch["input_ids"]).logits

        # Next-token objective: position t predicts token t + 1.
        shift_labels = batch["input_ids"][..., 1:].contiguous()
        shift_logits = logits[..., :-1, :].contiguous()

        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        # Scale so that the accumulated gradient is an average over the window.
        loss = loss / gradient_accumulation_steps

        accelerator.backward(loss)

        # Release the large activations before the optimizer step / evaluation.
        del logits
        del shift_logits
        del shift_labels

        # Also step on the last batch of the epoch so that gradients from a
        # partial accumulation window neither leak into the next epoch nor get
        # dropped after the final one.
        is_last_batch = step == batches_per_epoch
        if step % gradient_accumulation_steps == 0 or is_last_batch:
            accelerator.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            lr_sched.step()
            optimizer.zero_grad()
            completed_steps += 1

        if (step % (eval_steps * gradient_accumulation_steps)) == 0:
            eval_loss, perplexity = evaluate()
            accelerator.print({"loss/eval": eval_loss, "perplexity": perplexity})
            # evaluate() leaves the model in eval mode.
            model.train()
            accelerator.wait_for_everyone()

        progress.update(1)

progress.close()

In [None]:
text = """x ="""
newt = []          # ids of the newly generated tokens
max_new_tokens = 20

# Encode the prompt once and drop the padding added by the tokenizer, so the
# last position of the model input is always a real token.
encoded = tokenizer(text)
prompt_ids = [
    token_id
    for chunk, chunk_mask in zip(encoded["input_ids"], encoded["attention_mask"])
    for token_id, keep in zip(chunk, chunk_mask)
    if keep
]

# Run on whatever device the model lives on instead of assuming "cuda".
device = next(model.parameters()).device

model.eval()  # disable dropout for generation
for _ in range(max_new_tokens):
    # Feed prompt + everything generated so far, truncated to the context window.
    context = (prompt_ids + newt)[-context_length:]
    input_ids = torch.tensor([context], device=device)
    with torch.no_grad():
        logits = model(input_ids).logits
    # Greedy decoding: the logits at the last position predict the next token.
    newt.append(int(torch.argmax(logits[0, -1])))

text += tokenizer.decode(newt)
text