In [None]:
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer, GPT2LMHeadModel, AutoConfig, pipeline
from torch.nn import CrossEntropyLoss
from torch.utils.data.dataloader import DataLoader
from torch.optim import AdamW, lr_scheduler
from accelerate import Accelerator
from typing import defaultdict
import torch
import tqdm
import re
import pandas as pd
from unidecode import unidecode
import string

In [None]:
# Scratch check: slicing a str yields another str (here "ol"), so this
# expression evaluates to the `str` type.
type("lol"[1:3])

In [None]:
class Tokenizer:
    """Small byte-pair-encoding style tokenizer learned from a text corpus.

    Training lower-cases and transliterates the corpus, splits it on
    whitespace, and repeatedly merges the most frequent adjacent symbol pair
    until the vocabulary reaches ``vocab_size`` (or no pair is left).

    Calling the instance encodes text into fixed-length rows of token ids
    plus matching attention masks.

    Attributes:
        rules: Learned merges, in training order, as ``(left, right)`` tuples.
        tokens: Final ``(symbols, frequency)`` state of the training words.
        vocab: Mapping token string -> id.
        vocab_id: Mapping id -> token string.
        pad_token / pad_token_id: Padding token and its id (always 0).
        unk_token / unk_token_id: Unknown token and its id (always 1).
        max_len: Number of ids per encoded row.
    """

    @staticmethod
    def _merge_pair(symbols, first, second):
        """Return ``symbols`` as a list with every ``first, second`` run fused."""
        merged = f"{first}{second}"
        out = []
        i = 0
        total = len(symbols)
        while i < total:
            if i < total - 1 and symbols[i] == first and symbols[i + 1] == second:
                out.append(merged)
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        return out

    def get_pair_freqs(self, token_freqs):
        """Count adjacent symbol pairs, weighted by word frequency.

        Args:
            token_freqs: Iterable of ``(symbols, freq)`` where ``symbols`` is
                any indexable sequence of symbols (str, list or tuple).

        Returns:
            List of ``((left, right), count)`` sorted by ascending count.
        """
        pair_freqs = defaultdict(int)

        for token, freq in token_freqs:
            for i in range(len(token) - 1):
                # Build a tuple rather than slicing: a slice of a list is an
                # unhashable list and cannot be used as a dict key.
                pair = (token[i], token[i + 1])
                pair_freqs[pair] += freq

        return sorted(pair_freqs.items(), key=lambda x: x[1])

    def merge_tokens(self, token_freqs, pair):
        """Apply one merge to every word.

        Args:
            token_freqs: Iterable of ``(symbols, freq)``.
            pair: Two symbols to fuse; a 2-tuple or a 2-character string.

        Returns:
            New list of ``(symbol_list, freq)`` with the pair merged.
        """
        first, second = pair
        return [
            (self._merge_pair(token, first, second), freq)
            for token, freq in token_freqs
        ]

    def __len__(self):
        """Return the vocabulary size."""
        return len(self.vocab)

    def get_word_freqs(self, corpus):
        """Count whitespace-separated words over all texts in ``corpus``."""
        word_freqs = defaultdict(int)
        for text in corpus:
            for word in text.split():
                word_freqs[word] += 1

        return word_freqs

    def get_tokens(self, word_freqs):
        """Turn a word -> count mapping into a list of ``(word, count)``."""
        return list(word_freqs.items())

    def make_rules(self, vocab, vocab_size, tokens):
        """Learn merges until ``vocab`` holds ``vocab_size`` entries.

        ``vocab`` (a set) is extended in place. Training stops early when no
        adjacent pair is left to merge.

        Returns:
            ``(vocab, tokens, rules)`` where ``rules`` lists the merged
            ``(left, right)`` pairs in the order they were learned.
        """
        rules = []
        while len(vocab) < vocab_size:
            pair_freqs = self.get_pair_freqs(tokens)
            if not pair_freqs:
                # Every word is a single symbol: nothing left to merge.
                break
            pair = pair_freqs[-1][0]
            rules.append(pair)
            vocab.add(f"{pair[0]}{pair[1]}")
            tokens = self.merge_tokens(tokens, pair)
        return vocab, tokens, rules

    def train_vocab(self, corpus, vocab_size):
        """Normalise ``corpus`` and learn a vocabulary of ``vocab_size``.

        The starting vocabulary is the two special tokens plus every
        lower-cased character of ``string.printable``.
        """
        corpus = [unidecode(text).lower() for text in corpus]
        vocab = {"<|pad|>", "<|unk|>"} | set(string.printable.lower())
        word_freqs = self.get_word_freqs(corpus)
        tokens = self.get_tokens(word_freqs)
        vocab, tokens, rules = self.make_rules(vocab, vocab_size, tokens)

        return vocab, tokens, rules

    def __init__(self, corpus, vocab_size, max_len):
        """Train on ``corpus`` and build the token <-> id tables.

        Args:
            corpus: Iterable of training texts.
            vocab_size: Target vocabulary size.
            max_len: Length of each encoded row.
        """
        vocab, tokens, rules = self.train_vocab(corpus, vocab_size)

        self.pad_token = "<|pad|>"
        self.unk_token = "<|unk|>"
        self.pad_token_id = 0
        self.unk_token_id = 1

        # Sets have no stable order, so lay the ids out explicitly: the
        # special tokens get the ids advertised above, the rest are sorted so
        # ids are reproducible between runs.
        specials = [self.pad_token, self.unk_token]
        ordered = specials + sorted(set(vocab) - set(specials))

        self.rules = rules
        self.tokens = tokens
        self.vocab = {v: i for i, v in enumerate(ordered)}
        self.vocab_id = {b: a for a, b in self.vocab.items()}
        self.max_len = max_len

    def __call__(
        self,
        texts,
        truncation=False,
    ):
        """Encode text into rows of ``max_len`` token ids.

        All texts are concatenated into one symbol stream, the learned merges
        are replayed in order, and the resulting ids are cut into rows.

        Args:
            texts: A string or an iterable of strings.
            truncation: When true, drop a trailing row shorter than
                ``max_len`` instead of padding it.

        Returns:
            Dict with ``"input_ids"`` and ``"attention_mask"``, two lists of
            equally sized rows. Both are empty when there is nothing to emit.
        """
        if isinstance(texts, str):
            texts = [texts]

        # Normalise whole texts (as training does) and split into characters,
        # so multi-character transliterations become separate symbols.
        text = [char for item in texts for char in unidecode(item).lower()]

        for first, second in self.rules:
            text = self._merge_pair(text, first, second)

        n = self.max_len
        flat = [self.vocab.get(c, self.unk_token_id) for c in text]

        chunks = [
            flat[i : i + n]
            for i in range(0, len(flat), n)
            if not truncation or i + n <= len(flat)
        ]

        ids = []
        mask = []
        for chunk in chunks:
            pad = n - len(chunk)
            ids.append(chunk + [self.pad_token_id] * pad)
            # A fresh list per row; rows must not alias one another.
            mask.append([1] * len(chunk) + [0] * pad)

        return {
            "input_ids": ids,
            "attention_mask": mask,
        }

In [None]:
# Number of documents per split; kept tiny so the notebook runs quickly.
NNN = 10
# Number of passes over the training split.
epochsN = 3

# Download and shuffle once, then take two non-overlapping slices. Selecting
# the same range twice would make "valid" a copy of "train", and the
# evaluation loss would just be training loss.
_shuffled = load_dataset("huggingface-course/codeparrot-ds-train")["train"].shuffle(seed=42)

dataset = DatasetDict({
    "train": _shuffled.select(range(NNN)),
    "valid": _shuffled.select(range(NNN, 2 * NNN)),
})

# Length (in tokens) of every encoded row.
context_length = 128
# The tokenizer is fitted on the training documents only.
corpset = list(dataset["train"]["content"])

# Target vocabulary size for the tokenizer.
vs = 300
tokenizer = Tokenizer(corpset,vs,context_length)

In [None]:
# Smoke test: encode a short string and display the resulting dict of
# "input_ids" / "attention_mask" rows.
tokenizer("this is a string")

In [None]:
# Spot check of the id -> token table: show the token string stored under id 33.
tokenizer.vocab_id[33]

In [None]:
%%time

def tokenize(element):
    """Encode the "content" field of a batch with the notebook's tokenizer.

    Truncation is enabled, so the tokenizer is asked to drop a trailing
    partial row rather than pad it.
    """
    return tokenizer(element["content"], truncation=True)

# Encode every split in batches. The original columns are removed so only the
# tokenizer's output columns remain; the on-disk cache is bypassed so the
# mapping is recomputed on each run.
dataset_tokenized = dataset.map(tokenize, batched=True, remove_columns=dataset["train"].column_names, load_from_cache_file=False)

In [None]:
# Start from the stock GPT-2 configuration and shrink it to this notebook's
# tokenizer and context window.
config = AutoConfig.from_pretrained(
    "gpt2",
    vocab_size=len(tokenizer.vocab),
    # Kept for older transformers releases that still read it.
    n_ctx=context_length,
    # GPT2Config sizes the position-embedding table from n_positions; without
    # it the model keeps the default 1024 positions.
    n_positions=context_length,
)

In [None]:
# Batch size used by both loaders.
bs = 8

# Have the tokenized datasets return torch tensors, then wrap each split in a
# DataLoader; only the training split is shuffled.
dataset_tokenized.set_format("torch")
train_dataloader = DataLoader(dataset_tokenized["train"], batch_size=bs, shuffle=True)
eval_dataloader = DataLoader(dataset_tokenized["valid"], batch_size=bs)

In [None]:
# Weight-decay coefficient applied to the "decayed" parameter group.
weight_decay = 0.1

def get_grouped_params(
    model,
    no_decay=("bias", "LayerNorm.weight", "ln_1.weight", "ln_2.weight", "ln_f.weight"),
):
    """Split a model's parameters into decayed and non-decayed optimizer groups.

    A parameter is exempt from weight decay when any entry of ``no_decay``
    occurs as a substring of its name. The defaults cover biases and
    layer-norm gains, including the ``ln_1`` / ``ln_2`` / ``ln_f`` names used
    by GPT-2, which the plain ``"LayerNorm.weight"`` pattern does not match.

    Args:
        model: Object exposing ``named_parameters()`` yielding ``(name, param)``.
        no_decay: Name fragments marking parameters that must not be decayed.

    Returns:
        Two param-group dicts: the first uses the module-level
        ``weight_decay``, the second uses ``0.0``.
    """
    params_with_wd, params_without_wd = [], []
    for n, p in model.named_parameters():
        if any(nd in n for nd in no_decay):
            params_without_wd.append(p)
        else:
            params_with_wd.append(p)
    return [
        {"params": params_with_wd, "weight_decay": weight_decay},
        {"params": params_without_wd, "weight_decay": 0.0},
    ]

In [None]:
def evaluate():
    """Compute mean loss and perplexity of ``model`` over ``eval_dataloader``.

    Relies on the notebook-level ``model``, ``eval_dataloader`` and
    ``accelerator``. The model is switched to eval mode and left there; the
    caller is responsible for calling ``model.train()`` afterwards.

    Returns:
        ``(loss, perplexity)`` as Python floats. Both are NaN when the loader
        yields no batches; perplexity is ``inf`` when ``exp(loss)`` overflows.
    """
    model.eval()
    losses = []
    for batch in eval_dataloader:
        with torch.no_grad():
            outputs = model(batch["input_ids"], labels=batch["input_ids"])

        gathered = accelerator.gather_for_metrics(outputs.loss)
        # The gathered loss may be a scalar or hold one value per process;
        # flatten so the pieces can always be concatenated.
        losses.append(gathered.detach().float().reshape(-1).cpu())

    if not losses:
        return float("nan"), float("nan")

    loss = torch.cat(losses).mean()
    # torch.exp saturates to inf instead of raising, so no exception handling
    # is needed here.
    perplexity = torch.exp(loss)
    return loss.item(), perplexity.item()

In [None]:
# Build a GPT-2 language model from the config alone, i.e. without loading
# pretrained weights.
model = GPT2LMHeadModel(config)

In [None]:
# AdamW over the two parameter groups (with / without weight decay).
optimizer = AdamW(get_grouped_params(model), lr=5e-4)

In [None]:
# Request fp16 mixed precision.
# NOTE(review): fp16 presumably needs a GPU to be available - confirm before
# running on CPU-only machines.
accelerator = Accelerator(mixed_precision="fp16")

# Hand model, optimizer and both loaders to accelerate and rebind the names to
# the objects it returns; later cells use these prepared versions.
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

In [None]:
num_train_epochs = epochsN
# Number of batches per epoch (length of the prepared training loader).
num_update_steps_per_epoch = len(train_dataloader)
# Total number of batches over all epochs.
# NOTE(review): this counts batches, whereas the training loop performs one
# optimizer/scheduler step per `gradient_accumulation_steps` batches.
num_training_steps = num_train_epochs * num_update_steps_per_epoch
# Batches whose gradients are accumulated before each optimizer step.
gradient_accumulation_steps = 8
# Evaluation runs every eval_steps * gradient_accumulation_steps batches.
eval_steps = 50
# Counter of optimizer updates, incremented by the training loop.
completed_steps = 0

In [None]:
# The scheduler is stepped once per optimizer update, i.e. once every
# `gradient_accumulation_steps` batches, so its horizon must be measured in
# updates rather than batches; otherwise the schedule never completes.
# max(1, ...) keeps the horizon valid when an epoch is shorter than one
# accumulation window.
lr_sched = lr_scheduler.LinearLR(
    optimizer,
    total_iters=max(
        1,
        num_train_epochs * (num_update_steps_per_epoch // gradient_accumulation_steps),
    ),
)

In [None]:
model.train()
loss_fct = CrossEntropyLoss()

# The context manager closes the progress bar even if training is interrupted.
with tqdm.tqdm(total=num_training_steps) as progress:
    for epoch in range(num_train_epochs):
        for step, batch in enumerate(train_dataloader, start=1):
            logits = model(batch["input_ids"]).logits

            # Next-token objective: position t predicts token t + 1.
            shift_labels = batch["input_ids"][..., 1:].contiguous()
            shift_logits = logits[..., :-1, :].contiguous()

            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            # Scale so the accumulated gradient matches one large batch.
            # (To log the unscaled value, print
            # loss.item() * gradient_accumulation_steps.)
            loss = loss / gradient_accumulation_steps

            accelerator.backward(loss)

            # Release the large activations before the next forward pass.
            del logits
            del shift_logits
            del shift_labels

            # Also step on the last batch of the epoch, so gradients from a
            # partial accumulation window are applied instead of being lost
            # or carried into the next epoch.
            is_last_batch = step == len(train_dataloader)
            if step % gradient_accumulation_steps == 0 or is_last_batch:
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_sched.step()
                optimizer.zero_grad()
                completed_steps += 1

            if (step % (eval_steps * gradient_accumulation_steps)) == 0:
                eval_loss, perplexity = evaluate()
                accelerator.print({"loss/eval": eval_loss, "perplexity": perplexity})
                # evaluate() leaves the model in eval mode.
                model.train()
                accelerator.wait_for_everyone()

            progress.update(1)

In [None]:
text = """x ="""
newt = []

# Inference: turn dropout off and follow the model onto whatever device it
# lives on instead of assuming CUDA.
model.eval()
device = next(model.parameters()).device

# Encode the prompt once and keep only real (non-padding) ids.
encoded = tokenizer(text)
ids = [
    tok_id
    for row, row_mask in zip(encoded["input_ids"], encoded["attention_mask"])
    for tok_id, keep in zip(row, row_mask)
    if keep
]

# Greedy decoding: each new token is appended to the context so the next
# prediction is conditioned on it.
for _ in range(20):
    # Feed at most one context window of the most recent ids.
    window = torch.tensor([ids[-tokenizer.max_len:]], device=device)
    with torch.no_grad():
        logits = model(input_ids=window).logits
    # The prediction for the next token is at the last position.
    best_id = torch.argmax(logits[0, -1]).item()
    ids.append(best_id)
    newt.append(tokenizer.vocab_id[best_id])

# Tokens are subword pieces, so they are concatenated without separators.
text = text + "".join(newt)
text