In [None]:
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer, GPT2LMHeadModel, AutoConfig, pipeline
from torch.nn import CrossEntropyLoss
from torch.utils.data.dataloader import DataLoader
from torch.optim import AdamW, lr_scheduler
from accelerate import Accelerator
from typing import defaultdict
import torch
import tqdm
import re
import pandas as pd

In [None]:
# Number of examples kept from each dataset split (used in `.select(range(NNN))`).
NNN = 100
# Number of passes over the training dataloader (assigned to `num_train_epochs`).
epochsN = 3

In [None]:
class Tokenizer:
    """Small byte-pair-encoding (BPE) tokenizer trained on an in-memory corpus.

    Text is split into words with the GPT-2 pre-tokenizer, whose byte-level
    alphabet represents every byte as a printable character (a space becomes
    U+0120, a newline U+010A).  Training starts from the single characters
    seen in the corpus and repeatedly merges the most frequent adjacent pair
    until ``vocab_size`` entries exist.  Calling the instance encodes text
    into fixed-width rows of ids plus a matching attention mask.

    Attributes:
        vocab: token string -> id.
        vocab_id: id -> token string.
        rules: learned merges as ``(left, right)`` tuples, in training order.
        tokens: the training words in their final merged form, with counts.
        max_len: default row width used by ``__call__``.
    """

    # Special tokens always occupy the first two ids of the vocabulary.
    pad_token = "<|pad|>"
    unk_token = "<|unk|>"
    pad_token_id = 0
    unk_token_id = 1

    def __init__(self, corpus: list[str], vocab_size, max_len):
        """Fetch the GPT-2 pre-tokenizer and learn a vocabulary from ``corpus``.

        Args:
            corpus: training documents.
            vocab_size: target vocabulary size, special tokens included.
                Training stops early if the corpus runs out of pairs to merge.
            max_len: default number of ids per encoded row.
        """
        self.pre_tokenizer = AutoTokenizer.from_pretrained("gpt2")._tokenizer.pre_tokenizer
        self.max_len = max_len
        self._train(corpus, vocab_size)

    def __len__(self):
        """Return the number of entries in the vocabulary."""
        return len(self.vocab)

    def _train(self, corpus, vocab_size):
        """Learn ``vocab``, ``vocab_id``, ``rules`` and ``tokens`` from ``corpus``."""
        word_freqs = {}
        alphabet = set()
        for text in corpus:
            for word, _offsets in self.pre_tokenizer.pre_tokenize_str(text):
                word_freqs[word] = word_freqs.get(word, 0) + 1
                alphabet.update(word)

        # Sorting makes ids reproducible across runs and keeps the special
        # tokens at the ids advertised by pad_token_id / unk_token_id.
        vocab = [self.pad_token, self.unk_token] + sorted(alphabet)
        known = set(vocab)
        tokens = [(list(word), count) for word, count in word_freqs.items()]

        rules = []
        while len(vocab) < vocab_size:
            pair = self.pair_freqs(tokens)
            if pair is None:
                # Every word is a single symbol: nothing left to merge.
                break
            rules.append(pair)
            merged = pair[0] + pair[1]
            # Two different merges can spell the same string; store it once so
            # ids stay contiguous.
            if merged not in known:
                known.add(merged)
                vocab.append(merged)
            tokens = self.merge_tokens(tokens, pair)

        self.rules = rules
        self.tokens = tokens
        self.vocab = {token: index for index, token in enumerate(vocab)}
        self.vocab_id = {index: token for token, index in self.vocab.items()}

    def pair_freqs(self, tokens):
        """Return the most frequent adjacent symbol pair, or ``None`` if there is none.

        Args:
            tokens: list of ``(symbols, count)`` where ``symbols`` is a word
                split into its current symbols and ``count`` its frequency.

        Ties are resolved in favour of the pair encountered first.
        """
        freqs = {}
        for symbols, count in tokens:
            for pair in zip(symbols, symbols[1:]):
                freqs[pair] = freqs.get(pair, 0) + count
        if not freqs:
            return None
        return max(freqs.items(), key=lambda item: item[1])[0]

    @staticmethod
    def _merge_word(symbols, pair):
        """Return ``symbols`` with each left-to-right occurrence of ``pair`` fused."""
        left, right = pair
        merged = left + right
        out = []
        last = len(symbols) - 1
        i = 0
        while i <= last:
            if i < last and symbols[i] == left and symbols[i + 1] == right:
                out.append(merged)
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        return out

    def merge_tokens(self, tokens, pair):
        """Apply one merge to every word of ``tokens`` and return the new list.

        Args:
            tokens: list of ``(symbols, count)`` tuples.
            pair: ``(left, right)`` symbols to fuse.

        Returns:
            A new list of ``(symbols, count)`` tuples; the input is not modified.
        """
        return [(self._merge_word(symbols, pair), count) for symbols, count in tokens]

    def _encode_word(self, word):
        """Split one pre-tokenized word into symbols by replaying the merges in order."""
        symbols = list(word)
        for rule in self.rules:
            if len(symbols) < 2:
                break
            symbols = self._merge_word(symbols, rule)
        return symbols

    def convert_ids_to_tokens(self, ids):
        """Map an id to its token string, or a list/tuple of ids to a list of strings."""
        if isinstance(ids, (list, tuple)):
            return [self.vocab_id[i] for i in ids]
        return self.vocab_id[ids]

    @staticmethod
    def _byte_decoder():
        """Return the inverse of GPT-2's byte-to-unicode table (character -> byte)."""
        printable = (
            list(range(ord("!"), ord("~") + 1))
            + list(range(0xA1, 0xAC + 1))
            + list(range(0xAE, 0xFF + 1))
        )
        decoder = {chr(byte): byte for byte in printable}
        printable = set(printable)
        shift = 0
        # Non-printable bytes are shifted to code points starting at U+0100.
        for byte in range(256):
            if byte not in printable:
                decoder[chr(256 + shift)] = byte
                shift += 1
        return decoder

    def convert_tokens_to_string(self, tokens):
        """Join tokens and undo the byte-level encoding (U+0120 -> space, etc.).

        Characters outside the byte-level alphabet are kept as they are;
        byte sequences that are not valid UTF-8 become U+FFFD.
        """
        decoder = self._byte_decoder()
        raw = bytearray()
        for char in "".join(tokens):
            if char in decoder:
                raw.append(decoder[char])
            else:
                raw.extend(char.encode("utf-8"))
        return raw.decode("utf-8", errors="replace")

    def __call__(
        self,
        texts,
        truncation=False,
        max_length=None,
        return_tensors=None,
        padding=None,
    ):
        """Encode text into rows of ``max_length`` ids.

        All inputs are joined with a single space and encoded as one stream,
        which is then cut into consecutive rows, so the number of rows is
        unrelated to the number of input strings.

        Args:
            texts: a string or an iterable of strings.
            truncation: if truthy, drop a trailing row shorter than the row
                width; otherwise keep it and pad it with ``pad_token_id``.
            max_length: row width; defaults to ``self.max_len``.
            return_tensors: accepted for call compatibility; ignored, plain
                lists are always returned.
            padding: accepted for call compatibility; ignored, a short final
                row is always padded.

        Returns:
            ``{"input_ids": rows, "attention_mask": masks}`` where each mask
            holds 1 for real ids and 0 for padding.  Both lists are empty
            when no row is produced.

        Raises:
            ValueError: if the row width is not positive.
        """
        if isinstance(texts, str):
            # Joining a bare string would put a space between every character.
            texts = [texts]
        text = " ".join(texts)

        width = self.max_len if max_length is None else max_length
        if width <= 0:
            raise ValueError("max_length must be a positive integer")

        # Merges are replayed per word, exactly as they were learned.
        cache = {}
        symbols = []
        for word, _offsets in self.pre_tokenizer.pre_tokenize_str(text):
            if word not in cache:
                cache[word] = self._encode_word(word)
            symbols.extend(cache[word])

        ids = [self.vocab.get(symbol, self.unk_token_id) for symbol in symbols]
        rows = [
            ids[start : start + width]
            for start in range(0, len(ids), width)
            if not truncation or start + width <= len(ids)
        ]

        mask = [[1] * len(row) + [0] * (width - len(row)) for row in rows]
        rows = [row + [self.pad_token_id] * (width - len(row)) for row in rows]

        return {
            "input_ids": rows,
            "attention_mask": mask,
        }

In [None]:
# Load each split from its own repository so that evaluation data is disjoint
# from training data, and fetch every repository only once.
raw_train = load_dataset("huggingface-course/codeparrot-ds-train", split="train")
raw_valid = load_dataset("huggingface-course/codeparrot-ds-valid", split="validation")

# Keep a fixed-seed random sample of at most NNN examples per split.
dataset = DatasetDict({
    "train": raw_train.shuffle(seed=42).select(range(min(NNN, len(raw_train)))),
    "valid": raw_valid.shuffle(seed=42).select(range(min(NNN, len(raw_valid)))),
})

# Number of token ids per training row.
context_length = 128
# The tokenizer is fitted on the training split only, so nothing about the
# validation text leaks into the vocabulary.
corpset = list(dataset["train"]["content"])

# Target vocabulary size for the custom BPE tokenizer (the pretrained "gpt2"
# tokenizer would be the drop-in alternative).
vs = 300
tokenizer = Tokenizer(corpset, vs, context_length)

In [None]:
%%time

def tokenize(element):
    """Encode the ``content`` column of a batch with the notebook's ``tokenizer``.

    The call asks for truncation and ``context_length``-wide rows and hands
    back the tokenizer's result (``input_ids`` and ``attention_mask``) untouched.
    """
    return tokenizer(
        element["content"],
        max_length=context_length,
        padding="max_length",
        truncation=True,
    )


# Tokenize every split in batches and drop the original columns so only the
# tokenizer's outputs remain.
dataset_tokenized = dataset.map(
    tokenize,
    batched=True,
    remove_columns=dataset["train"].column_names,
)  # optional extra argument: load_from_cache_file=False

In [None]:
# Start from the stock "gpt2" configuration and override the vocabulary size
# (to match the custom tokenizer) and `n_ctx` (set to the row length).
# NOTE(review): whether `n_ctx` still controls the positional-embedding size
# depends on the installed transformers version — confirm (`n_positions` may be
# the effective setting).
config = AutoConfig.from_pretrained(
    "gpt2",
    vocab_size=len(tokenizer.vocab),
    n_ctx=context_length,
    # The custom tokenizer defines no BOS/EOS tokens, so these stay disabled.
    #bos_token_id=tokenizer.bos_token_id,
    #eos_token_id=tokenizer.eos_token_id,
)

In [None]:
# Examples per batch for both dataloaders.
bs = 8

# Make the tokenized dataset return torch tensors (the result is not
# reassigned, so the call is relied on to change the dataset in place).
dataset_tokenized.set_format("torch")
# Training batches are reshuffled; validation batches keep dataset order.
train_dataloader = DataLoader(dataset_tokenized["train"], batch_size=bs, shuffle=True)
eval_dataloader = DataLoader(dataset_tokenized["valid"], batch_size=bs)

In [None]:
# Weight-decay coefficient applied to the decayed parameter group.
weight_decay = 0.1


def get_grouped_params(
    model,
    no_decay=("bias", "LayerNorm.weight", "ln_1.weight", "ln_2.weight", "ln_f.weight"),
):
    """Split a model's parameters into weight-decayed and non-decayed groups.

    Args:
        model: any object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs.
        no_decay: substrings; a parameter whose name contains any of them is
            excluded from weight decay.  The defaults cover biases and
            layer-norm gains, including GPT-2's ``ln_1`` / ``ln_2`` / ``ln_f``
            layer names, which the BERT-style ``"LayerNorm.weight"`` pattern
            alone does not match.

    Returns:
        Two optimizer parameter groups: the first uses the module-level
        ``weight_decay``, the second uses ``0.0``.
    """
    params_with_wd, params_without_wd = [], []
    for name, param in model.named_parameters():
        if any(marker in name for marker in no_decay):
            params_without_wd.append(param)
        else:
            params_with_wd.append(param)
    return [
        {"params": params_with_wd, "weight_decay": weight_decay},
        {"params": params_without_wd, "weight_decay": 0.0},
    ]

In [None]:
def evaluate():
    """Compute mean loss and perplexity of ``model`` over ``eval_dataloader``.

    Uses the notebook-level ``model``, ``eval_dataloader`` and ``accelerator``.
    The model is left in eval mode; callers that continue training must call
    ``model.train()`` afterwards.

    Returns:
        ``(loss, perplexity)`` as Python floats, where perplexity is
        ``exp(loss)`` (``inf`` if that overflows).  Both are ``nan`` when the
        dataloader yields no batches.
    """
    model.eval()
    losses = []
    for batch in eval_dataloader:
        with torch.no_grad():
            outputs = model(batch["input_ids"], labels=batch["input_ids"])

        gathered = accelerator.gather_for_metrics(outputs.loss)
        # The gathered loss may be 0-dim or 1-dim; flatten it so the
        # per-batch results can always be concatenated.
        losses.append(gathered.detach().float().reshape(-1))

    if not losses:
        return float("nan"), float("nan")

    loss = torch.cat(losses).mean()
    # torch.exp saturates to inf instead of raising, so no try/except is needed.
    perplexity = torch.exp(loss)
    return loss.item(), perplexity.item()

In [None]:
# Build the GPT-2 language model directly from `config`; no pretrained
# checkpoint is loaded here.
model = GPT2LMHeadModel(config)

In [None]:
# AdamW over the two parameter groups returned by get_grouped_params
# (with / without weight decay), base learning rate 5e-4.
optimizer = AdamW(get_grouped_params(model), lr=5e-4)

In [None]:
# Accelerator configured for fp16 mixed precision.
accelerator = Accelerator(mixed_precision="fp16")

# prepare() returns the objects to use from here on; they replace the
# originals under the same names.
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

In [None]:
num_train_epochs = epochsN
# Batches per epoch (length of the prepared training dataloader).
# NOTE(review): despite the name this counts batches, not optimizer updates —
# the training loop only calls optimizer.step() every
# `gradient_accumulation_steps` batches.
num_update_steps_per_epoch = len(train_dataloader)
# Total number of batches processed over all epochs.
num_training_steps = num_train_epochs * num_update_steps_per_epoch
# Number of batches whose gradients are summed before one optimizer update.
gradient_accumulation_steps = 8
# Evaluation interval; the training loop multiplies it by
# `gradient_accumulation_steps` to get an interval in batches.
eval_steps = 50
# Running count of optimizer updates, incremented by the training loop.
completed_steps = 0

In [None]:
# The training loop calls lr_sched.step() once per optimizer update, i.e. once
# every `gradient_accumulation_steps` batches, so the schedule length has to be
# expressed in updates rather than in batches.
num_optimizer_steps = max(1, num_training_steps // gradient_accumulation_steps)
# With LinearLR's default factors the learning rate ramps linearly from 1/3 of
# the base value up to the base value over `total_iters` updates.
lr_sched = lr_scheduler.LinearLR(optimizer, total_iters=num_optimizer_steps)

In [None]:
model.train()
# CrossEntropyLoss skips targets equal to -100 (its default ignore_index).
loss_fct = CrossEntropyLoss()
progress = tqdm.tqdm(total=num_training_steps)

# Batches seen across all epochs.  Counting globally (instead of per epoch)
# keeps accumulation groups exactly `gradient_accumulation_steps` long even
# when the epoch length is not a multiple of it.
micro_step = 0
last_eval_at = None
optimizer.zero_grad()

for epoch in range(num_train_epochs):
    for batch in train_dataloader:
        micro_step += 1
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]

        logits = model(input_ids, attention_mask=attention_mask).logits

        # Position t predicts token t + 1.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = input_ids[..., 1:].contiguous()
        # Padded positions must not contribute to the loss.
        shift_labels = shift_labels.masked_fill(attention_mask[..., 1:] == 0, -100)

        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        # Scale so the summed gradients of one group average the batch losses.
        loss = loss / gradient_accumulation_steps

        accelerator.backward(loss)

        # Release activations before the optimizer step.
        del logits, shift_logits, shift_labels

        is_final_batch = micro_step == num_training_steps
        if micro_step % gradient_accumulation_steps == 0 or is_final_batch:
            # The final (possibly shorter) group is applied too, so no
            # computed gradient is thrown away.
            accelerator.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            lr_sched.step()
            optimizer.zero_grad()
            completed_steps += 1

            if completed_steps % eval_steps == 0:
                eval_loss, perplexity = evaluate()
                accelerator.print({"loss/eval": eval_loss, "perplexity": perplexity})
                last_eval_at = completed_steps
                model.train()
                accelerator.wait_for_everyone()

        progress.update(1)

progress.close()

# Short runs never reach the periodic interval; report once at the end unless
# the last update was already evaluated.
if last_eval_at != completed_steps:
    eval_loss, perplexity = evaluate()
    accelerator.print({"loss/eval": eval_loss, "perplexity": perplexity})
    model.train()
    accelerator.wait_for_everyone()

In [None]:
text = """x = """
newt = []

# Encode the prompt once.  It is passed as a one-element list because the
# tokenizer joins its argument with spaces.  Padding is stripped using the
# attention mask so that only real prompt tokens condition the model.
encoded = tokenizer([text])
ids = [
    token_id
    for row, row_mask in zip(encoded["input_ids"], encoded["attention_mask"])
    for token_id, keep in zip(row, row_mask)
    if keep
]
if not ids:
    raise ValueError("prompt produced no tokens; nothing to condition generation on")

model.eval()  # disable dropout while sampling
for _ in range(20):
    # Feed at most the most recent `max_len` ids, the row width used in training.
    window = ids[-tokenizer.max_len:]
    input_ids = torch.tensor([window], device=accelerator.device)
    with torch.no_grad():
        logits = model(input_ids=input_ids).logits
    # Greedy decoding: the prediction for the next token sits at the last position.
    best_id = torch.argmax(logits[0, -1]).item()
    # Append the new id so the next step is conditioned on it.
    ids.append(best_id)
    newt.append(tokenizer.convert_ids_to_tokens(best_id))

text = text + tokenizer.convert_tokens_to_string(newt)
text